A pair of pure functions implementing a gradient transformation.
Prefer GradientTransformationExtraArgs for new optimizers.
Optax optimizers are all implemented as gradient transformations. A gradient transformation is defined to be a pair of pure functions, combined in a NamedTuple so that they can be referred to by name.
Note that an extended API is provided for users wishing to build optimizers that take additional arguments during the update step. For more details, see optax.GradientTransformationExtraArgs().
Since gradient transformations do not contain any internal state, all stateful optimizer properties (such as the current step count when using optimizer schedules or momentum values) are passed through optax gradient transformations by using the optimizer state pytree. Each time a gradient transformation is applied, a new state is computed and returned, ready to be passed to the next call to the gradient transformation.
Since gradient transformations are pure functions, the only way to change the behavior of a gradient transformation between steps is to change the values in the optimizer state. For an example of mutating the optimizer state to control the behavior of an optax gradient transformation, see the meta-learning example in the optax documentation.
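For instance, the following sketch, assuming the optax.inject_hyperparams wrapper (which exposes hyperparameters such as the learning rate inside the optimizer state), shows one way to alter an optimizer's behavior between steps by overwriting a value stored in its state:
import jax.numpy as jnp
import optax

params = jnp.array([1., 2., 3.])
# inject_hyperparams stores the learning rate in the optimizer state,
# where it can be overwritten between steps without rebuilding the optimizer.
opt = optax.inject_hyperparams(optax.sgd)(learning_rate=0.1)
state = opt.init(params)
state.hyperparams['learning_rate'] = 0.01  # alter behavior for the next step
updates, state = opt.update(jnp.ones_like(params), state)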
A pure function which, when called with an example instance of the parameters whose gradients will be transformed, returns a pytree containing the initial value for the optimizer state.
A pure function which takes as input a pytree of updates (with the same tree structure as the original params pytree passed to init), the previous optimizer state (which may have been initialized using the init function), and optionally the current params. The update function then returns the computed gradient updates, and a new optimizer state.
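Putting the two together, a minimal sketch of the init/update contract, using the simple (stateless) optax.scale transformation:
import jax
import jax.numpy as jnp
import optax

tx = optax.scale(-0.1)              # a simple gradient transformation
params = jnp.array([1., 2., 3.])
state = tx.init(params)             # init: build the (here empty) state
grads = jax.grad(lambda p: jnp.sum(p ** 2))(params)
updates, state = tx.update(grads, state)  # update: transform grads, new state
params = optax.apply_updates(params, updates)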
A specialization of GradientTransformation that supports extra args.
Extends the existing GradientTransformation interface by adding support for passing extra arguments to the update function.
Note that if no extra args are provided, then the API of this function is identical to the case of TransformUpdateFn. This means that we can safely wrap any gradient transformation (that does not support extra args) as one that does. The new gradient transformation will accept (and ignore) any extra arguments that a user might pass to it. This is the behavior implemented by optax.with_extra_args_support().
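A small sketch of this wrapping behavior; the extra keyword argument here is illustrative and simply ignored by the wrapped transform:
import jax.numpy as jnp
import optax

tx = optax.with_extra_args_support(optax.scale(-0.1))
params = jnp.array([1., 2., 3.])
state = tx.init(params)
grads = jnp.ones_like(params)
# 'value' is not used by optax.scale; the wrapper accepts and ignores it.
updates, state = tx.update(grads, state, params, value=42.0)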
Overrides the type signature of the update in the base type to accept extra arguments.
optax._src.base.TransformUpdateExtraArgsFn
A callable type for the init step of a GradientTransformation.
The init step takes a tree of params and uses these to construct an arbitrary structured initial state for the gradient transformation. This may hold statistics of the past updates or any other non-static information.
A callable type for the update step of a GradientTransformation.
The update step takes a tree of candidate parameter updates (e.g. their gradient with respect to some loss), an arbitrary structured state, and the current params of the model being optimized. The params argument is optional; it must however be provided when using transformations that require access to the current values of the parameters. For the case where additional arguments are required, an alternative interface may be used; see TransformUpdateExtraArgsFn for details.
alias of Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]
alias of Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]
alias of Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]
Clips updates to be at most clipping * parameter_norm, unit-wise.
clipping – The maximum allowed ratio of update norm to parameter norm.
eps – An epsilon term to prevent clipping of zero-initialized params.
axis – Axis or axes along which to compute the unit-wise norm. If None, uses default behavior based on input dimensions. This is useful for custom parameter shapes like Conv3D (ndim=5).
An optax.GradientTransformation object.
References
Brock et al., High-Performance Large-Scale Image Recognition Without Normalization, 2021
alias of EmptyState
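A short sketch of chaining this unit-wise adaptive clipping (optax.adaptive_grad_clip, per the reference above) in front of an optimizer; the clipping ratio is illustrative:
import jax.numpy as jnp
import optax

opt = optax.chain(
    optax.adaptive_grad_clip(clipping=0.01),  # illustrative clipping ratio
    optax.sgd(learning_rate=1e-3),
)
params = {'w': jnp.ones((3, 2))}
state = opt.init(params)
grads = {'w': jnp.full((3, 2), 10.0)}  # large gradients get clipped unit-wise
updates, state = opt.update(grads, state, params)  # params needed for the ratio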
Add parameter scaled by weight_decay.
weight_decay – A scalar weight decay rate.
mask – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.
An optax.GradientTransformation object.
alias of EmptyState
Add gradient noise.
eta – Base variance of the Gaussian noise added to the gradient.
gamma – Decay exponent for annealing of the variance.
key – Random generator key for noise generation.
seed – Deprecated; use key instead.
An optax.GradientTransformation object.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> key = jax.random.key(0)  # could also be key=0
>>> noise = optax.add_noise(eta=0.01, gamma=0.55, key=key)
>>> sgd = optax.scale_by_learning_rate(learning_rate=0.003)
>>> solver = optax.chain(noise, sgd)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.35E+01
Objective function: 1.33E+01
Objective function: 1.32E+01
References
Neelakantan et al, Adding Gradient Noise Improves Learning for Very Deep Networks, 2015
State for adding gradient noise. Contains a count for annealing.
Accumulate gradients and apply them every k steps.
Note that if this transformation is part of a chain, the states of the other transformations will still be updated at every step. In particular, using apply_every with a batch size of N/2 and k=2 is not necessarily equivalent to not using apply_every with a batch size of N. If this equivalence is important for you, consider using optax.MultiSteps.
k – Emit non-zero gradients every k steps, otherwise accumulate them.
An optax.GradientTransformation object.
Contains a counter and a gradient accumulator.
Performs bias correction. It becomes a no-op as count goes to infinity.
Centralizes gradients by subtracting their mean along leading dimension.
An optax.GradientTransformation object.
Example
>>> import jax.numpy as jnp
>>> import optax
>>> grad = jnp.array([[1, 2, 3], [4, 5, 6]])
>>> opt = optax.centralize()
>>> state = opt.init(grad)
>>> updates, state = opt.update(grad, state)
>>> print(updates)
[[-1.  0.  1.]
 [-1.  0.  1.]]
>>> print(state)
EmptyState()
References
Yong et al, Gradient Centralization: A New Optimization Technique for Deep Neural Networks, 2020.
Calls the inner update function only at certain steps.
Creates a transformation wrapper that conditionally applies the inner gradient transformation; if the condition is not met, the updates are set to 0, while the inner state is passed through unchanged. The behavior is controlled by a user-specified function should_transform_fn that is called by conditionally_mask with, as input, a counter of the number of times that the update function has previously been called; the user-specified function must return a boolean controlling whether the inner transformation should be called.
inner – the inner transformation.
should_transform_fn – function that takes in a step counter (array of shape [] and dtype int32) and returns a boolean array of shape []. If forward_extra_args is set to True, any extra arguments are also forwarded to should_transform_fn.
forward_extra_args – forward extra args to should_transform_fn.
Warning
If instead you want to leave updates unchanged when the condition is not met, you can use the conditionally_transform wrapper.
Added in version 0.2.3.
Calls the inner update function only at certain steps.
Creates a transformation wrapper that conditionally applies the inner gradient transformation; if the condition is not met, it just passes the updates and inner state through unchanged. The behavior is controlled by a user-specified function should_transform_fn that is called by conditionally_transform with, as input, a counter of the number of times that the update function has previously been called; the user-specified function must return a boolean controlling whether the inner transformation should be called.
inner – the inner transformation.
should_transform_fn – function that takes in a step counter (array of shape [] and dtype int32) and returns a boolean array of shape []. If forward_extra_args is set to True, any extra arguments are also forwarded to should_transform_fn.
forward_extra_args – forward extra args to should_transform_fn.
Warning
If instead you want to set the updates to zero when the condition is not met, you can use the conditionally_mask wrapper.
Added in version 0.2.3.
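A small sketch of the behavior described above, applying an inner transformation only during the first 100 steps; the choice of inner transformation and threshold is illustrative:
import jax.numpy as jnp
import optax

# Apply global-norm clipping only for the first 100 steps.
opt = optax.conditionally_transform(
    optax.clip_by_global_norm(1.0),
    should_transform_fn=lambda step: step < 100,
)
params = jnp.array([1., 2., 3.])
state = opt.init(params)
updates, state = opt.update(jnp.ones_like(params), state, params)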
Maintains inner transform state and adds a step counter.
Clips updates element-wise, to be in [-max_delta, +max_delta].
max_delta – The maximum absolute value for each element in the update.
An optax.GradientTransformation object.
Clips updates to a max rms for the gradient of each param vector or matrix.
A block is here a weight vector (e.g. in a Linear layer) or a weight matrix (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree.
threshold – The maximum rms for the gradient of each param vector or matrix.
An optax.GradientTransformation object.
alias of EmptyState
Clips updates using their global norm.
max_norm – The maximum global norm for an update.
An optax.GradientTransformation object.
References
Pascanu et al., On the difficulty of training Recurrent Neural Networks, 2012
alias of EmptyState
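A typical use of optax.clip_by_global_norm, clipping gradients before SGD (a sketch; the norm bound is illustrative):
import jax.numpy as jnp
import optax

opt = optax.chain(
    optax.clip_by_global_norm(1.0),  # rescale so the global norm is at most 1
    optax.sgd(learning_rate=1e-2),
)
params = jnp.array([1., 2., 3.])
state = opt.init(params)
grads = jnp.array([10., 10., 10.])  # global norm > 1, so it gets rescaled
updates, state = opt.update(grads, state, params)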
Compute an exponential moving average of past updates.
decay – Decay rate for the exponential moving average.
debias – Whether to debias the transformed gradient.
accumulator_dtype – Optional dtype to be used for the accumulator; if None then the dtype is inferred from params and updates.
An optax.GradientTransformation object.
Note
optax.trace() and optax.ema() have very similar but distinct updates; trace = decay * trace + t, while ema = decay * ema + (1-decay) * t. Both are frequently found in the optimization literature.
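A tiny arithmetic sketch of the two recursions, to make the difference concrete:
import jax.numpy as jnp

decay, t = 0.9, jnp.array(1.0)
trace = ema = jnp.array(0.0)
for _ in range(3):
  trace = decay * trace + t              # optax.trace: no (1 - decay) factor
  ema = decay * ema + (1 - decay) * t    # optax.ema: convex combination
# trace grows toward t / (1 - decay); ema converges to t.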
Holds an exponential moving average of past updates.
An empty state for the simplest stateless transformations.
Compute the global norm across a nested structure of tensors.
Stateless identity transformation that leaves input gradients untouched.
This function passes through the gradient updates unchanged.
Note, this should not be confused with set_to_zero, which maps the input updates to zero; that is the transform required for the model parameters to be left unchanged when the updates are applied to them.
An optax.GradientTransformation object.
Modifies the updates to keep parameters non-negative, i.e. >= 0.
This transformation ensures that parameters after the update will be larger than or equal to zero. In a chain of transformations, this should be the last one.
An optax.GradientTransformation object.
Warning
The transformation expects input params to be non-negative. When params are negative, the transformed update will move them to 0.
alias of EmptyState
Mask updates so only some are transformed, the rest are passed through.
For example, it is common to skip weight decay for BatchNorm scale and all bias parameters. Since in many networks, these are the only 1D parameters, you may for instance create a mask function to mask them out as follows:
mask_fn = lambda p: jax.tree.map(lambda x: x.ndim != 1, p)
weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask_fn)
You may alternatively create the mask pytree upfront:
mask = jax.tree.map(lambda x: x.ndim != 1, params)
weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask)
For the inner transform, state will only be stored for the parameters that have a mask value of True.
Note that, when using tree_map_params, it may be required to pass the argument is_leaf=lambda v: isinstance(v, optax.MaskedNode) if the tree map needs to take additional arguments with the same shape as the original input tree.
inner – Inner transformation to mask.
mask – a PyTree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip. The mask must be static for the gradient transformation to be jit-compilable.
mask_compatible_extra_args – whether to also apply the same masking to extra_arg fields with the same tree structure as params/updates.
A new optax.GradientTransformationExtraArgs wrapping inner.
Scale by the inverse of the update norm.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.normalize_by_update_norm(scale_factor=-1.0)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function:', f(params))
Objective function: 14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 7.52E+00
Objective function: 3.03E+00
Objective function: 5.50E-01
Objective function: 6.67E-02
Objective function: 5.50E-01
scale_factor – factor by which the update will be multiplied (defaults to 1).
eps – jitter term to avoid dividing by 0.
An optax.GradientTransformation object.
Applies gradient clipping per-example using their global norm.
grads – flattened update; the function expects each array in this list to have a batch dimension on the 0th axis.
l2_norm_clip – maximum L2 norm of the per-example gradients.
A tuple containing the sum of the clipped per-example grads and the number of per-example grads that were clipped.
Example
>>> import optax
>>> import jax.numpy as jnp
>>> grads = [jnp.array([[0, 0, 0], [0, 3, 4], [4, 0, 3], [3, 4, 0]])]
>>> optax.per_example_global_norm_clip(grads, jnp.inf)
([Array([7., 7., 7.], dtype=float32)], Array(0, dtype=int32))
>>> optax.per_example_global_norm_clip(grads, 0.0)
([Array([0., 0., 0.], dtype=float32)], Array(3, dtype=int32))
>>> optax.per_example_global_norm_clip(grads, 1.25)
([Array([1.75, 1.75, 1.75], dtype=float32)], Array(3, dtype=int32))
References
Abadi et al., Deep Learning with Differential Privacy, 2016
Applies gradient clipping per-example using per-layer norms.
If len(grads) == 1, this function is equivalent to optax.per_example_global_norm_clip. If len(grads) > 1, each array in grads will be independently clipped to the value C_i documented below.
Let C denote the value of global_l2_norm_clip. Then per-layer clipping is done as follows:
1. If uniform is True, each of the K layers has an individual clip norm of C / sqrt(K).
2. If uniform is False, each of the K layers has an individual clip norm of C * sqrt(D_i / D), where D_i is the number of parameters in layer i and D is the total number of parameters in the model.
grads – flattened update; i.e. a list of gradients in which each item is the gradient for one layer; the function expects these to have a batch dimension on the 0th axis.
global_l2_norm_clip – overall L2 clip norm to use.
uniform – If True, per-layer clip norm is global_l2_norm_clip / sqrt(L), where L is the number of layers. Otherwise, per-layer clip norm is global_l2_norm_clip * sqrt(f), where f is the fraction of total model parameters that are in this layer.
A tuple containing the sum of the clipped per-example grads and the number of per-example grads that were clipped for each layer.
Example
>>> import optax
>>> import jax.numpy as jnp
>>> grads = [jnp.array([[0, 0, 0], [0, 3, 4], [4, 0, 3], [3, 4, 0]])]
>>> optax.per_example_layer_norm_clip(grads, jnp.inf)
([Array([7., 7., 7.], dtype=float32)], [Array(0, dtype=int32)])
>>> optax.per_example_layer_norm_clip(grads, 0.0)
([Array([0., 0., 0.], dtype=float32)], [Array(3, dtype=int32)])
>>> optax.per_example_layer_norm_clip(grads, 1.25)
([Array([1.75, 1.75, 1.75], dtype=float32)], [Array(3, dtype=int32)])
References
McMahan et al., Learning Differentially Private Recurrent Language Models, 2017
Scale updates by some fixed scalar step_size.
step_size – A scalar corresponding to a fixed scaling factor for updates.
An optax.GradientTransformation object.
alias of EmptyState
Rescale updates according to the Adadelta algorithm.
See optax.adadelta() for more details.
rho – A coefficient used for computing a running average of squared gradients.
eps – Term added to the denominator to improve numerical stability.
An optax.GradientTransformation object.
State for the rescaling by Adadelta algorithm.
Rescale updates according to the Adan algorithm.
See optax.adan() for more details.
b1 – Decay rate for the EWMA of gradients.
b2 – Decay rate for the EWMA of differences of gradients.
b3 – Decay rate for the EWMA of the algorithm's squared term.
eps – Term added to the denominator to improve numerical stability.
eps_root – Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.
An optax.GradientTransformation object.
Rescale updates according to the Adam algorithm.
See optax.adam() for more details.
b1 – Decay rate for the exponentially weighted average of grads.
b2 – Decay rate for the exponentially weighted average of squared grads.
eps – Term added to the denominator to improve numerical stability.
eps_root – Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.
mu_dtype – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
nesterov – Whether to use Nesterov momentum. The variant of Adam with Nesterov momentum is described in [Dozat 2016].
An optax.GradientTransformation object.
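Chaining this transformation with a (negative) learning-rate scaling recovers the Adam optimizer; a sketch, equivalent up to argument defaults to optax.adam:
import optax

learning_rate = 1e-3
opt = optax.chain(
    optax.scale_by_adam(b1=0.9, b2=0.999),
    optax.scale_by_learning_rate(learning_rate),  # scales by -learning_rate
)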
Rescale updates according to the Adamax algorithm.
See optax.adamax() for more details.
b1 – Decay rate for the exponentially weighted average of grads.
b2 – Decay rate for the exponentially weighted maximum of grads.
eps – Term added to the denominator to improve numerical stability.
An optax.GradientTransformation object.
State for the Adam algorithm.
Rescale updates according to the AMSGrad algorithm.
See optax.amsgrad() for more details.
b1 – Decay rate for the exponentially weighted average of grads.
b2 – Decay rate for the exponentially weighted average of squared grads.
eps – Term added to the denominator to improve numerical stability.
eps_root – Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.
mu_dtype – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
An optax.GradientTransformation object.
State for the AMSGrad algorithm.
Backtracking line-search ensuring sufficient decrease (Armijo criterion).
Selects learning rate \(\eta\) such that it verifies the sufficient decrease criterion
\[f(w + \eta u) \leq (1+\delta)f(w) + \eta c \langle u, \nabla f(w) \rangle + \epsilon \,, \]
where
\(f\) is the function to minimize, \(w\) are the current parameters, \(\eta\) is the learning rate to find, \(u\) is the update direction, \(c\) is a coefficient (slope_rtol) measuring the relative decrease of the function in terms of the slope (scalar product between the gradient and the updates), \(\delta\) is a relative tolerance (rtol), and \(\epsilon\) is an absolute tolerance (atol).
The algorithm starts with a given guess of a learning rate and decreases it by decrease_factor until the criterion above is met.
max_backtracking_steps – maximum number of iterations for the line-search.
slope_rtol – relative tolerance w.r.t. the slope. The sufficient decrease must be slope_rtol * lr * <grad, updates>; see formula above.
decrease_factor – decreasing factor to reduce learning rate.
increase_factor – increasing factor to increase the learning rate guess. Setting it to 1. amounts to keeping the current guess; setting it to math.inf amounts to starting with max_learning_rate at each round.
max_learning_rate – maximum learning rate (learning rate guess clipped to this).
atol – absolute tolerance at which the criterion needs to be satisfied.
rtol – relative tolerance at which the criterion needs to be satisfied.
store_grad – whether to compute and store the gradient at the end of the linesearch. Since the function is called to compute the value to accept the learning rate, we can also access the gradient along the way. By doing that, we can directly reuse the value and the gradient computed at the end of the linesearch for the next iteration using optax.value_and_grad_from_state(). See the examples below.
verbose – whether to print debugging information.
A GradientTransformationExtraArgs, where the update function takes the following additional keyword arguments:
value: value of the function at the current params.
grad: gradient of the function at the current params.
value_fn: function returning the value of the function we seek to optimize.
**extra_args: additional keyword arguments; if the function needs additional arguments such as input data, they should be put there (see example in this docstring).
Examples
An example on using the backtracking line-search with SGD:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> solver = optax.chain(
...     optax.sgd(learning_rate=1.),
...     optax.scale_by_backtracking_linesearch(max_backtracking_steps=15)
... )
>>> # Function with additional inputs other than params
>>> def fn(params, x, y): return optax.l2_loss(x.dot(params), y)
>>> params = jnp.array([1., 2., 3.])
>>> opt_state = solver.init(params)
>>> x, y = jnp.array([3., 2., 1.]), jnp.array(0.)
>>> xs, ys = jnp.tile(x, (5, 1)), jnp.tile(y, (5,))
>>> opt_state = solver.init(params)
>>> print('Objective function: {:.2E}'.format(fn(params, x, y)))
Objective function: 5.00E+01
>>> for x, y in zip(xs, ys):
...   value, grad = jax.value_and_grad(fn)(params, x, y)
...   updates, opt_state = solver.update(
...       grad,
...       opt_state,
...       params,
...       value=value,
...       grad=grad,
...       value_fn=fn,
...       x=x,
...       y=y
...   )
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(fn(params, x, y)))
Objective function: 3.86E+01
Objective function: 2.50E+01
Objective function: 1.34E+01
Objective function: 5.87E+00
Objective function: 5.81E+00
A similar example, but with a non-stochastic function where we can reuse the value and the gradient computed at the end of the linesearch:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> # Function without extra arguments
>>> def fn(params): return jnp.sum(params ** 2)
>>> params = jnp.array([1., 2., 3.])
>>> # In this case we can store value and grad with the store_grad field
>>> # and reuse them using optax.value_and_grad_from_state
>>> solver = optax.chain(
...     optax.sgd(learning_rate=1.),
...     optax.scale_by_backtracking_linesearch(
...         max_backtracking_steps=15, store_grad=True
...     )
... )
>>> opt_state = solver.init(params)
>>> print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 1.40E+01
>>> value_and_grad = optax.value_and_grad_from_state(fn)
>>> for _ in range(5):
...   value, grad = value_and_grad(params, state=opt_state)
...   updates, opt_state = solver.update(
...       grad, opt_state, params, value=value, grad=grad, value_fn=fn
...   )
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 5.04E+00
Objective function: 1.81E+00
Objective function: 6.53E-01
Objective function: 2.35E-01
Objective function: 8.47E-02
References
Vaswani et al., Painless Stochastic Gradient, 2019
Nocedal & Wright, Numerical Optimization, 1999
Warning
The sufficient decrease criterion might be impossible to satisfy for some update directions. To guarantee a non-trivial solution for the sufficient decrease criterion, a descent direction for updates (\(u\)) is required. An update (\(u\)) is considered a descent direction if the derivative of \(f(w + \eta u)\) at \(\eta = 0\) (i.e., \(\langle u, \nabla f(w)\rangle\)) is negative. This condition is automatically satisfied when using optax.sgd() (without momentum), but may not hold true for other optimizers like optax.adam().
More generally, when chained with other transforms as optax.chain(opt_1, ..., opt_k, scale_by_backtracking_linesearch(max_backtracking_steps=...), opt_kplusone, ..., opt_n), the updates returned by chaining opt_1, ..., opt_k must be a descent direction. However, any transform after the backtracking line-search doesn't necessarily need to satisfy the descent direction property (one could for example use momentum).
Note
The algorithm can support complex inputs.
Added in version 0.2.0.
State for optax.scale_by_backtracking_linesearch().
learning rate computed at the end of a round of line-search, used to scale the update.
Union[float, jax.Array]
value of the objective computed at the end of a round of line-search. Can be reused using optax.value_and_grad_from_state().
Union[float, jax.Array]
gradient of the objective computed at the end of a round of line-search if the line-search is instantiated with store_grad = True; otherwise it is None. Can be reused using optax.value_and_grad_from_state().
Optional[base.Updates]
information about the backtracking linesearch step, for debugging.
BacktrackingLinesearchInfo
Rescale updates according to the AdaBelief algorithm.
See optax.adabelief() for more details.
b1 – Decay rate for the exponentially weighted average of grads.
b2 – Decay rate for the exponentially weighted average of variance of grads.
eps – Term added to the denominator to improve numerical stability.
eps_root – Term added to the second moment of the prediction error to improve numerical stability. If backpropagating gradients through the gradient transformation (e.g. for meta-learning), this must be non-zero.
nesterov – Whether to use Nesterov momentum.
An optax.GradientTransformation object.
State for the rescaling by AdaBelief algorithm.
Scaling by a factored estimate of the gradient rms (as in Adafactor).
This is a so-called “1+epsilon” scaling algorithm that is extremely memory efficient compared to RMSProp/Adam, and has had wide success when applied to large-scale training of attention-based models.
factored – boolean: whether to use factored second-moment estimates.
decay_rate – float: controls second-moment exponential decay schedule.
step_offset – for finetuning, one may set this to the starting step-number of the fine tuning phase.
min_dim_size_to_factor – only factor accumulator if two array dimensions are at least this size.
epsilon – Regularization constant for squared gradient.
decay_rate_fn – A function that accepts the current step, the decay rate parameter and controls the schedule for the second momentum. Defaults to the original adafactor’s power decay schedule. One potential shortcoming of the original schedule is the fact that second momentum converges to 1, which effectively freezes the second momentum. To prevent this the user can opt for a custom schedule that sets an upper bound for the second momentum, like in Zhai et al., 2021.
The corresponding optax.GradientTransformation.
References
Shazeer et al, Adafactor: Adaptive Learning Rates with Sublinear Memory Cost, 2018
Zhai et al, Scaling Vision Transformers, 2021
Overall state of the gradient transformation.
Scales updates by L-BFGS.
L-BFGS is a quasi-Newton method that multiplies the update (gradient) with an approximation of the inverse Hessian. This algorithm does not need access to the Hessian, as this approximation is constructed from the gradient evaluations seen during optimization. L-BFGS is a limited-memory variant of the Broyden-Fletcher-Goldfarb-Shanno (BFGS) algorithm. The BFGS algorithm requires storing a matrix of size \(p \times p\) with \(p\) the dimension of the parameters. The limited-memory variant circumvents this issue by computing the approximation of the inverse using only \(m\) (memory_size) past differences of parameters/gradients. Namely, the approximation of the Hessian inverse is denoted \(P_k = P_{k, k}\), where
\[\begin{align*} P_{k, j+1} & = V_j^\top P_{k, j} V_j + \rho_j \delta w_j \delta w_j^\top \quad \text{for} \ j \in \{k-m, \ldots, k-1\}\\ P_{k, k-m} & = \gamma_k I \\ V_k & = I - \rho_k \delta u_k \delta w_k^\top \\ \rho_k & = 1/(\delta u_k^\top \delta w_k) \\ \delta w_k & = w_{k+1} - w_k \\ \delta u_k & = u_{k+1} - u_k \\ \gamma_k & = \begin{cases} (\delta w_{k-1}^\top \delta u_{k-1}) / (\delta u_{k-1}^\top \delta u_{k-1}) & \text{if} \ \texttt{scale\_init\_hess} \\ 1 & \text{otherwise} \end{cases}, \end{align*}\]
for \(u_k\) the gradients/updates at iteration \(k\), \(w_k\) the parameters at iteration \(k\).
The formula for updating \(P_k\) is obtained by computing the optimal preconditioning matrix subject to some secant condition, see references for more details. Computing \(P_k u_k\) can be done by a sequence of vector operations using past differences of parameters and gradients stored in a memory buffer.
The present function just outputs the L-BFGS direction \(P_k u_k\). It can be chained with a linesearch ensuring sufficient decrease and low curvature, such as a zoom linesearch. The linesearch computes a stepsize \(\eta_k\) such that the updated parameters (using optax.apply_updates()) take the form \(w_{k+1} = w_k - \eta_k P_k u_k\).
memory_size – number of past parameters, gradients/updates to keep in memory to approximate the Hessian inverse.
scale_init_precond – whether to use a scaled identity as the initial preconditioner, see formula of \(\gamma_k\) above.
An optax.GradientTransformation object.
References
Algorithms 7.4, 7.5 (page 199) of Nocedal et al., Numerical Optimization, 1999
Liu et al., On the limited memory BFGS method for large scale optimization, 1989.
Note
We initialize the scaling of the identity as a capped reciprocal of the gradient norm. This avoids wasting linesearch iterations for the first step by taking into account the magnitude of the gradients. In other words, we constrain the trust-region of the first step to an Euclidean ball of radius 1 at the first iteration. The choice of \(\gamma_0\) is not detailed in the references above, so this is a heuristic choice.
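As noted above, this transformation is typically chained with a linesearch. A sketch under that assumption (the sign flip via optax.scale makes the preconditioned direction a descent step; this pairing is essentially what optax.lbfgs() provides):
import jax.numpy as jnp
import optax

def fn(params):
  return jnp.sum(params ** 2)

solver = optax.chain(
    optax.scale_by_lbfgs(),
    optax.scale(-1.0),  # descend along the preconditioned direction
    optax.scale_by_zoom_linesearch(max_linesearch_steps=15),
)
params = jnp.array([1., 2., 3.])
opt_state = solver.init(params)
value_and_grad = optax.value_and_grad_from_state(fn)
for _ in range(3):
  value, grad = value_and_grad(params, state=opt_state)
  updates, opt_state = solver.update(
      grad, opt_state, params, value=value, grad=grad, value_fn=fn)
  params = optax.apply_updates(params, updates)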
State for LBFGS solver.
iteration of the algorithm.
chex.Numeric
current parameters.
base.Params
current updates.
base.Params
represents a list of past parameters' differences up to some predetermined memory_size fixed in optax.scale_by_lbfgs().
chex.ArrayTree
represents a list of past gradients/updates' differences up to some predetermined memory_size fixed in optax.scale_by_lbfgs().
chex.ArrayTree
list of past weights multiplying the rank-one matrices defining the inverse Hessian approximation; see optax.scale_by_lbfgs() for more details.
chex.Array
Scale by the (negative) learning rate (either as scalar or as schedule).
learning_rate – Can either be a scalar or a schedule (i.e. a callable that maps an (int) step to a float). None means no scaling.
flip_sign – When set to True (the default) this corresponds to scaling by the negative learning rate.
An optax.GradientTransformation that corresponds to multiplying the gradient with -learning_rate (if flip_sign is True) or with learning_rate (if flip_sign is False).
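A sketch showing both a scalar and a schedule being passed as the learning rate; the schedule values are illustrative:
import optax

opt_const = optax.scale_by_learning_rate(1e-2)  # multiplies updates by -1e-2
schedule = optax.exponential_decay(
    init_value=1e-2, transition_steps=1000, decay_rate=0.9)
opt_sched = optax.scale_by_learning_rate(schedule)  # step-dependent scaling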
Rescale updates according to the Lion algorithm.
See optax.lion() for more details.
b1 – Rate for combining the momentum and the current grad.
b2 – Decay rate for the exponentially weighted average of grads.
mu_dtype – Optional dtype to be used for the momentum; if None then the dtype is inferred from params and updates.
An optax.GradientTransformation object.
State for the Lion algorithm.
Computes NovoGrad updates.
See optax.novograd() for more details.
b1 – A decay rate for the exponentially weighted average of grads.
b2 – A decay rate for the exponentially weighted average of squared grads.
eps – A term added to the denominator to improve numerical stability.
eps_root – A term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.
weight_decay – A scalar weight decay rate.
mu_dtype – An optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
The corresponding optax.GradientTransformation.
State for Novograd.
Compute generalized optimistic gradients.
See optax.optimistic_adam() and optax.optimistic_gradient_descent() for more details.
alpha – Coefficient for generalized optimistic gradient descent.
beta – Coefficient for negative momentum.
An optax.GradientTransformation object.
Scale updates for each param block by the norm of that block’s parameters.
A block is here a weight vector (e.g. in a Linear layer) or a weight matrix (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree.
min_scale – Minimum scaling factor.
An optax.GradientTransformation object.
Scale updates by rms of the gradient for each param vector or matrix.
A block is here a weight vector (e.g. in a Linear layer) or a weight matrix (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree.
min_scale – Minimum scaling factor.
An optax.GradientTransformation object.
Rescale updates according to the Rectified Adam algorithm.
See optax.radam() for more details.
b1 – Decay rate for the exponentially weighted average of grads.
b2 – Decay rate for the exponentially weighted average of squared grads.
eps – Term added to the denominator to improve numerical stability.
eps_root – Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.
threshold – Threshold for variance tractability.
nesterov – Whether to use Nesterov momentum.
An optax.GradientTransformation object.
Scales the update by Polyak’s step-size.
See optax.polyak_sgd() for more details.
f_min – a lower bound on the objective function (defaults to 0). Corresponds to \(f^\star\) in the formula above.
max_learning_rate – a maximum step size to use (defaults to 1).
eps – a value to add in the denominator of the update (defaults to 0).
variant – either 'sps' or 'sps+' (defaults to 'sps').
An optax.GradientTransformationExtraArgs, where the update function takes an additional keyword argument value containing the current value of the objective function.
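A sketch of passing the required value keyword; chaining a unit-step SGD direction with the Polyak step size is, under this assumption, essentially what optax.polyak_sgd() provides:
import jax
import jax.numpy as jnp
import optax

def fn(params):
  return jnp.sum(params ** 2)

solver = optax.chain(
    optax.sgd(learning_rate=1.0),              # plain descent direction
    optax.scale_by_polyak(max_learning_rate=1.0),
)
params = jnp.array([1., 2., 3.])
state = solver.init(params)
value, grad = jax.value_and_grad(fn)(params)
# The current objective value must be passed as a keyword argument.
updates, state = solver.update(grad, state, params, value=value)
params = optax.apply_updates(params, updates)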
Rescale updates by the root of the exp. moving avg of the square.
See optax.rmsprop() for more details.
decay – Decay rate for the exponentially weighted average of squared grads.
eps – Term added to the denominator to improve numerical stability.
initial_scale – Initial value for second moment.
eps_in_sqrt – Whether to add eps in the square root of the denominator or outside the square root.
bias_correction – Whether to apply bias correction to the exponentially weighted average of squared grads.
An optax.GradientTransformation object.
Note
Using scale_by_rms(decay=b2, eps_in_sqrt=False, bias_correction=True) will match the behavior of scale_by_adam(b1=0, b2=b2), while sparing the memory cost of storing the first moment.
State for exponential root mean-squared (RMS)-normalized updates.
Scale with the Rprop optimizer.
See optax.rprop() for more details.
learning_rate – The initial step size.
eta_minus – Multiplicative factor for decreasing step size. This is applied when the gradient changes sign from one step to the next.
eta_plus – Multiplicative factor for increasing step size. This is applied when the gradient has the same sign from one step to the next.
min_step_size – Minimum allowed step size. Smaller steps will be clipped to this value.
max_step_size – Maximum allowed step size. Larger steps will be clipped to this value.
The corresponding optax.GradientTransformation.
Rescale updates by the root of the sum of all squared gradients to date.
See optax.adagrad() for more details.
initial_accumulator_value – Starting value for accumulators, must be >= 0.
eps – A small floating point value to avoid zero denominator.
An optax.GradientTransformation object.
State holding the sum of gradient squares to date.
Scale updates using a custom schedule for the step_size.
step_size_fn – A function that takes an update count as input and proposes the step_size to multiply the updates by.
An optax.GradientTransformation object.
Maintains count for scale scheduling.
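A small sketch using a custom schedule callable; the negative sign makes the scaled update a descent step when applied with optax.apply_updates, and the decay rule is illustrative:
import jax.numpy as jnp
import optax

# The schedule maps the update count to a (here negative) step size.
opt = optax.scale_by_schedule(lambda count: -0.1 / (1.0 + count))
params = jnp.array([1., 2., 3.])
state = opt.init(params)
updates, state = opt.update(jnp.ones_like(params), state)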
Compute the signs of the gradient elements.
An optax.GradientTransformation that contains the signs of the input gradient.
Scale updates by sm3.
See optax.sm3() for more details.
b1 – Decay rate for the exponentially weighted average of grads.
b2 – Decay rate for the exponentially weighted average of squared grads.
eps – Term added to the denominator to improve numerical stability.
An optax.GradientTransformation object.
State for the SM3 algorithm.
Rescale updates by the root of the centered exp. moving average of squares.
See optax.rmsprop() for more details.
decay – Decay rate for the exponentially weighted average of squared grads.
eps – Term added to the denominator to improve numerical stability.
initial_scale – Initial value for second moment.
eps_in_sqrt – Whether to add eps in the square root of the denominator or outside the square root.
bias_correction – Whether to apply bias correction to the first and second moment.
An optax.GradientTransformation object.
State for centered exponential moving average of squares of updates.
Scale updates by trust ratio.
Used in optax.fromage(), optax.lars(), and optax.lamb().
min_norm – Minimum norm for params and gradient norms; by default is zero.
trust_coefficient – A multiplier for the trust ratio.
eps – Additive constant added to the denominator for numerical stability.
An optax.GradientTransformation object.
alias of EmptyState
Rescale updates according to the Yogi algorithm.
See optax.yogi() for more details.
Supports complex numbers, see https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29
b1 – Decay rate for the exponentially weighted average of grads.
b2 – Decay rate for the exponentially weighted average of variance of grads.
eps – Term added to the denominator to improve numerical stability.
eps_root – Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.
initial_accumulator_value – The starting value for accumulators. Only positive values are allowed.
An optax.GradientTransformation object.
Linesearch ensuring sufficient decrease and small curvature.
This algorithm searches for a learning rate, a.k.a. stepsize, that satisfies both a sufficient decrease criterion, a.k.a. Armijo-Goldstein criterion,
\[f(w + \eta u) \leq f(w) + \eta c_1 \langle u, \nabla f(w) \rangle + \epsilon \,, \]
and a small curvature (along the update direction) criterion, a.k.a. Wolfe or second Wolfe criterion,
\[|\langle \nabla f(w + \eta u), u \rangle| \leq c_2 |\langle \nabla f(w), u \rangle| + \epsilon\,, \]
where
\(f\) is the function to minimize,
\(w\) are the current parameters,
\(\eta\) is the learning rate to find,
\(u\) is the update direction,
\(c_1\) is a coefficient (slope_rtol) measuring the relative decrease of the function in terms of the slope (scalar product between the gradient and the updates),
\(c_2\) is a coefficient (curv_rtol) measuring the relative decrease of curvature,
\(\epsilon\) is an absolute tolerance (tol).
To deal with very flat functions, this linesearch switches from the sufficient decrease criterion presented above to an approximate sufficient decrease criterion introduced by Hager and Zhang (see [Hager and Zhang, 2006]).
\[|\langle \nabla f(w+\eta u), u \rangle| \leq (2 c_1 - 1) |\langle \nabla f(w), u \rangle| + \epsilon\,. \]
The approximate curvature criterion is taken only if the values tried by the linesearch fall below a relative decrease of the initial function, that is,
\[f(w + \eta u) \leq f(w) + c_3 |f(w)| \]
where \(c_3\) is a coefficient (approx_dec_rtol) measuring the relative decrease of the objective (see reference below and comments in the code for more details).
The original sufficient decrease criterion can only capture differences up to \(\sqrt{\varepsilon_{machine}}\), while the approximate sufficient decrease criterion can capture differences up to \(\varepsilon_{machine}\) (see [Hager and Zhang, 2006]). Note that this add-on is not part of the original implementation of [Nocedal and Wright, 1999] and can be removed by setting approx_dec_rtol to None.
max_linesearch_steps – maximum number of linesearch iterations.
max_learning_rate – maximum admissible learning rate. Can be set to None for no upper bound. A non-None value may prevent the linesearch from finding a learning rate satisfying the small curvature criterion, since the latter may require sufficiently large stepsizes.
tol – tolerance on the criteria.
increase_factor – increasing factor used to augment the learning rate when searching for a valid interval containing a learning rate satisfying both criteria.
slope_rtol – relative tolerance for the slope in the sufficient decrease criterion.
curv_rtol – relative tolerance for the curvature in the small curvature criterion.
approx_dec_rtol – relative tolerance for the initial value in the approximate sufficient decrease criterion. Can be set to None to use only the original Armijo-Goldstein decrease criterion.
stepsize_precision – precision in the search of a stepsize satisfying both conditions. The algorithm proceeds with a bisection that refines an interval containing a stepsize satisfying both conditions. If that interval is reduced below stepsize_precision and a stepsize satisfying a sufficient decrease has been found, the algorithm selects that stepsize even if the curvature condition is not satisfied.
initial_guess_strategy – initial guess for the learning rate used to start the linesearch. Can be either one or keep. If one, the initial guess is set to 1. If keep, the initial guess is set to the learning rate of the previous step. We recommend using keep if this linesearch is used in combination with SGD, and one if it is used in combination with Newton methods or quasi-Newton methods such as L-BFGS.
verbose – whether to print additional debugging information in case the linesearch fails.
An optax.GradientTransformationExtraArgs object consisting of an init and an update function.
Examples
An example on using the zoom line-search with SGD:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> solver = optax.chain(
...     optax.sgd(learning_rate=1.),
...     optax.scale_by_zoom_linesearch(max_linesearch_steps=15)
... )
>>> # Function with additional inputs other than params
>>> def fn(params, x, y): return optax.l2_loss(x.dot(params), y)
>>> params = jnp.array([1., 2., 3.])
>>> opt_state = solver.init(params)
>>> x, y = jnp.array([3., 2., 1.]), jnp.array(0.)
>>> xs, ys = jnp.tile(x, (5, 1)), jnp.tile(y, (5,))
>>> opt_state = solver.init(params)
>>> print('Objective function: {:.2E}'.format(fn(params, x, y)))
Objective function: 5.00E+01
>>> for x, y in zip(xs, ys):
...   value, grad = jax.value_and_grad(fn)(params, x, y)
...   updates, opt_state = solver.update(
...       grad,
...       opt_state,
...       params,
...       value=value,
...       grad=grad,
...       value_fn=fn,
...       x=x,
...       y=y
...   )
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(fn(params, x, y)))
Objective function: 2.56E-13
Objective function: 2.84E-14
Objective function: 0.00E+00
Objective function: 0.00E+00
Objective function: 0.00E+00
A similar example, but with a non-stochastic function where we can reuse the value and the gradient computed at the end of the linesearch:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> # Function without extra arguments
>>> def fn(params): return jnp.sum(params ** 2)
>>> params = jnp.array([1., 2., 3.])
>>> solver = optax.chain(
...     optax.sgd(learning_rate=1.),
...     optax.scale_by_zoom_linesearch(max_linesearch_steps=15)
... )
>>> opt_state = solver.init(params)
>>> print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 1.40E+01
>>> value_and_grad = optax.value_and_grad_from_state(fn)
>>> for _ in range(5):
...   value, grad = value_and_grad(params, state=opt_state)
...   updates, opt_state = solver.update(
...       grad, opt_state, params, value=value, grad=grad, value_fn=fn
...   )
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(fn(params)))
Objective function: 0.00E+00
Objective function: 0.00E+00
Objective function: 0.00E+00
Objective function: 0.00E+00
Objective function: 0.00E+00
References
Algorithms 3.5, 3.6 of Nocedal and Wright, Numerical Optimization, 1999
Hager and Zhang Algorithm 851: CG_DESCENT, a Conjugate Gradient Method with Guaranteed Descent, 2006
Note
The curvature criterion can be avoided by setting curv_rtol=jnp.inf. The resulting algorithm will amount to a backtracking linesearch where a point satisfying sufficient decrease is searched for by minimizing a quadratic or cubic approximation of the objective. This can be sufficient in practice and avoids having the linesearch spend many iterations trying to satisfy the small curvature criterion.
Note
The algorithm can support complex inputs.
State for scale_by_zoom_linesearch.
learning rate computed at the end of a round of line-search, used to scale the update.
chex.Numeric
value of the objective computed at the end of a round of line-search. Can be reused using optax.value_and_grad_from_state().
chex.Numeric
gradient of the objective computed at the end of a round of line-search. Can be reused using optax.value_and_grad_from_state().
base.Updates
Additional information on the status of the linesearch; see optax.ZoomLinesearchInfo.
Stateless transformation that maps input gradients to zero.
The resulting update function, when called, will return a tree of zeros matching the shape of the input gradients. This means that when the updates returned from this transformation are applied to the model parameters, the model parameters will remain unchanged.
This can be used in combination with partition or masked to freeze (i.e. keep fixed) some parts of the tree of model parameters while applying gradient updates to other parts of the tree.
When updates are set to zero inside the same jit-compiled function as the calculation of gradients, optax transformations, and application of updates to parameters, unnecessary computations will in general be dropped.
An optax.GradientTransformation object.
Creates a stateless transformation from an update-like function.
This wrapper eliminates the boilerplate needed to create a transformation that does not require saved state between iterations.
f – Update function that takes in updates (e.g. gradients) and parameters and returns updates. The parameters may be None.
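A sketch of a custom transformation built this way; the scaling rule here is a toy example, not a standard optimizer:
import jax
import jax.numpy as jnp
import optax

def my_update_fn(updates, params):
  # A toy rule: shrink each update in proportion to the parameter magnitude.
  return jax.tree.map(lambda u, p: u / (1.0 + jnp.abs(p)), updates, params)

tx = optax.stateless(my_update_fn)
params = jnp.array([1., 2., 3.])
state = tx.init(params)
updates, state = tx.update(jnp.ones_like(params), state, params)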
Creates a stateless transformation from an update-like function for arrays.
This wrapper eliminates the boilerplate needed to create a transformation that does not require saved state between iterations, just like optax.stateless. In addition, this function will apply the tree map over update/params for you.
f – Update function that takes in an update array (e.g. gradients) and parameter array and returns an update array. The parameter array may be None.
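The same toy rule as in the previous sketch, written per-array; the tree map over updates/params is applied for you:
import jax
import jax.numpy as jnp
import optax

# f operates on a single update array and its matching parameter array.
tx = optax.stateless_with_tree_map(lambda u, p: u / (1.0 + jnp.abs(p)))
params = {'w': jnp.ones(3), 'b': jnp.zeros(2)}
state = tx.init(params)
grads = jax.tree.map(jnp.ones_like, params)
updates, state = tx.update(grads, state, params)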
Takes a snapshot of updates and stores it in the state.
Useful to debug intermediate updates values in a chained transformation.
measure_name – Name of the measurement to store. Can be then used to retrieve the snapshot using optax.tree.get(state, measure_name).
measure – User callable taking as inputs updates and returning desired measurement. When this transformation is part of a chain, the updates are the transformed gradients up to that transform.
A gradient transformation that captures measurements defined by the user in the callable measure and stores them in the state with the name measure_name.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)
>>> solver = optax.chain(
...     optax.sgd(learning_rate=0.1, momentum=0.9),
...     optax.snapshot('norm_before_clip', lambda x: optax.tree.norm(x)),
...     optax.clip_by_global_norm(0.05)
... )
>>> params = jnp.array([1., 2., 3.])
>>> state = solver.init(params)
>>> for step in range(2):
...   grads = jax.grad(f)(params)
...   updates, state = solver.update(grads, state)
...   params = optax.apply_updates(params, updates)
...   norm = optax.tree.get(state, 'norm_before_clip')
...   print(f'{step=}, {norm=}')
step=0, norm=Array(0.7483, dtype=float32)
step=1, norm=Array(1.4118, dtype=float32)
Compute a trace of past updates.
decay – Decay rate for the trace of past updates.
nesterov – Whether to use Nesterov momentum.
accumulator_dtype – Optional dtype to be used for the accumulator; if None then the dtype is inferred from params and updates.
An optax.GradientTransformation object.
Note
optax.trace() and optax.ema() have very similar but distinct updates; trace = decay * trace + t, while ema = decay * ema + (1-decay) * t. Both are frequently found in the optimization literature.
Holds an aggregation of past updates.
Compute the exponential moving average of the infinity norm.
Compute the exponential moving average of the order-th moment.
Compute the EMA of the order-th moment of the element-wise norm.
Wraps a gradient transformation, so that it ignores extra args.
A transformation which replaces NaNs with 0.
The state of the transformation has the same tree structure as that of the parameters. Each leaf is a single boolean which contains True iff a NaN was detected in the corresponding parameter array at the last call to update. This state is not used by the transformation internally, but lets users be aware when NaNs have been zeroed out.
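A minimal sketch of the NaN-zeroing behavior:
import jax.numpy as jnp
import optax

tx = optax.zero_nans()
grads = {'w': jnp.array([jnp.nan, 1.0])}
state = tx.init(grads)
updates, state = tx.update(grads, state)
# updates['w'] is now [0., 1.] and the state records that 'w' contained a NaN.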
Contains a tree.
The entry found_nan has the same tree structure as that of the parameters. Each leaf is a single boolean which contains True iff a NaN was detected in the corresponding parameter array at the last call to update.
Information about the zoom linesearch step, exposed for debugging.
A positive curvature error is not stringent; it can be due to a maximal learning rate that is too small. A positive value in the sufficient decrease error is more problematic, as it means that the algorithm may not be guaranteed to produce monotonically decreasing values. Consider using verbose=True in scale_by_zoom_linesearch() for additional failure diagnostics if the linesearch fails.
number of linesearch steps
int | jax.Array | numpy.ndarray | numpy.bool | numpy.number | float
sufficient decrease error. A positive value indicates that the linesearch failed to find a stepsize that ensures a sufficient decrease. A null value indicates it succeeded in finding such a stepsize.
float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | int
small curvature error. A positive value indicates that the linesearch failed to find a stepsize that ensures a small curvature. A null value indicates it succeeded in finding such a stepsize.
float | jax.Array | numpy.ndarray | numpy.bool | numpy.number | int
Create a transformation that zeros out gradient updates for mask=True.
This essentially freezes (i.e. holding constant) masked parameters.
The mask must be static (i.e., not dependent on runtime values or updated during training) and can be:
a single boolean (or 0-d JAX bool array), causing every parameter to be either all-frozen (True) or all-trainable (False), or
a PyTree of booleans matching the structure of the parameters, where each leaf indicates whether that specific parameter leaf should be frozen (True) or left unchanged (False).
mask – A boolean prefix tree mask indicating which parameters to freeze.
Example
>>> import jax.numpy as jnp
>>> from optax import freeze
>>> params = {'a': jnp.zeros(1), 'b': jnp.zeros(2)}
>>> mask = {'a': True, 'b': False}  # Freeze 'a', train 'b'
>>> freezer = freeze(mask)
An Optax GradientTransformation which applies set_to_zero() wherever mask==True, and leaves other gradients intact.
Partition updates so that only un-frozen parameters are optimized.
Example
>>> import jax.numpy as jnp
>>> import optax
>>> from optax import selective_transform
>>> params = {'a': jnp.zeros(1), 'b': jnp.zeros(2)}
>>> mask = {'a': True, 'b': False}  # Freeze 'a', train 'b'
>>> selective_opt = selective_transform(optax.adam(1e-3), freeze_mask=mask)
optimizer – The inner Optax optimizer to apply to unfrozen leaves.
freeze_mask – A static mask (i.e., not dependent on runtime values or updated during training). It can be either:
a scalar bool (or 0-d JAX bool array) to freeze everything (True) or nothing (False), or
a PyTree of booleans mirroring the parameter tree, marking each leaf to freeze (True) or train (False).
A GradientTransformation that routes each parameter leaf through the given optimizer if its mask is False (“train”), or set_to_zero() if its mask is True (“freeze”).
See also
optax.freeze(): For simply zeroing out gradients according to a mask.