Closure conversion utility, for use with higher-order custom derivatives.
To define custom derivatives such as with jax.custom_vjp(f), the target function f must take, as formal arguments, all values involved in differentiation. If f is a higher-order function, in that it accepts as an argument a Python function g, then values stored away in g’s closure will not be visible to the custom derivative rules, and attempts at AD involving these values will fail. One way around this is to convert the closure by extracting these values and passing them as explicit formal arguments across the custom derivative boundary. This utility carries out that conversion. More precisely, it closure-converts the function fun specialized to the types of the arguments given in example_args.
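Here is a minimal sketch of the conversion. It assumes a recent JAX release that exposes this utility as jax.closure_convert, and the names scaled_sum, f, w and x are hypothetical. The closure f captures w. Calling the converted function with the arguments of f, followed by the hoisted values, reproduces f(x), and gradients still reach w through those explicit arguments.

In current JAX releases, only closed-over values that may carry a derivative (such as tracers created by jax.grad) appear to be hoisted. Ordinary constants stay inside the converted function, so the list of hoisted values can be empty outside a transformation.

```python
import jax
import jax.numpy as jnp


def scaled_sum(w, x):
    # `f` is a closure over `w`; under jax.grad, `w` is a tracer.
    def f(v):
        return jnp.sum(w * v)

    # Convert `f`, specialized to the shape and dtype of `x`.
    # Outside any transformation `w` is a plain constant, so `hoisted` may be empty.
    converted_f, hoisted = jax.closure_convert(f, x)
    # The converted function takes the arguments of `f`, then the hoisted values.
    return converted_f(x, *hoisted)


w = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([4.0, 5.0, 6.0])

value = scaled_sum(w, x)             # sum(w * x) = 4 + 10 + 18 = 32
grad_w = jax.grad(scaled_sum)(w, x)  # d/dw sum(w * x) = x
```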
When we refer here to “values in the closure” of fun, we do not mean the values that are captured by Python directly when fun is defined (e.g. the Python objects in fun.__closure__, if the attribute exists). Rather, we mean values encountered during the execution of fun on example_args that determine its output. This may include, for instance, arrays captured transitively in Python closures, i.e. in the Python closure of functions called by fun, the closures of the functions that they call, and so forth.
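The sketch below illustrates the transitive case. All names are hypothetical. w is captured by helper rather than by model, so model.__closure__ holds only helper. Because closure_convert traces model on x, it still encounters w.

The call is made under jax.grad so that w is a tracer, which is the kind of value current JAX releases appear to hoist. hoisted_counts is a hypothetical side channel that only records how many values were hoisted.

```python
import jax
import jax.numpy as jnp

# Hypothetical side channel: records how many values each conversion hoisted.
hoisted_counts = []


def make_model(w):
    def helper(x):
        # `w` is captured here, in the closure of `helper`...
        return jnp.tanh(w * x)

    def model(x):
        # ...so the closure of `model` holds only `helper`.
        return jnp.sum(helper(x) ** 2)

    return model


def loss(w, x):
    model = make_model(w)
    converted, hoisted = jax.closure_convert(model, x)
    hoisted_counts.append(len(hoisted))
    return converted(x, *hoisted)


w = jnp.array([0.5, -1.0, 0.25])
x = jnp.array([1.0, 2.0, 3.0])
grad_w = jax.grad(loss)(w, x)  # d/dw sum(tanh(w * x) ** 2)
```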
The function fun must be a pure function.
Example usage:
    from functools import partial

    from jax import closure_convert, custom_vjp

    def minimize(objective_fn, x0):
      # Hoist values from objective_fn's closure into explicit arguments.
      converted_fn, aux_args = closure_convert(objective_fn, x0)
      return _minimize(converted_fn, x0, *aux_args)

    @partial(custom_vjp, nondiff_argnums=(0,))
    def _minimize(objective_fn, x0, *args):
      z = objective_fn(x0, *args)
      # ... find minimizer x_opt ...
      return x_opt

    def fwd(objective_fn, x0, *args):
      y = _minimize(objective_fn, x0, *args)
      # Residuals: the solution and the hoisted closure values.
      return y, (y, args)

    def rev(objective_fn, res, g):
      y, args = res
      y_bar = g
      # ... custom reverse-mode AD ...
      # Return one cotangent for x0 and one for each hoisted value.
      return x0_bar, *args_bars

    _minimize.defvjp(fwd, rev)
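The usage example above elides the minimizer and its reverse-mode rule. The sketch below is a runnable instance of the same pattern under simplifying assumptions:

- the hypothetical _apply merely evaluates the converted function, standing in for _minimize;
- its backward rule recomputes the VJP with jax.vjp rather than implementing a solver-specific rule.

The parameter w reaches _apply only through the closure of a lambda, yet its gradient flows through the hoisted arguments.

```python
from functools import partial

import jax
import jax.numpy as jnp


def apply(fn, x):
    # Hoist closed-over, differentiable values out of `fn` so they become
    # explicit arguments of the custom_vjp function below.
    converted_fn, aux_args = jax.closure_convert(fn, x)
    return _apply(converted_fn, x, *aux_args)


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def _apply(fn, x, *args):
    # Stand-in for `_minimize`: just evaluate the converted function.
    return fn(x, *args)


def _apply_fwd(fn, x, *args):
    # Save the inputs as residuals; the backward pass recomputes the VJP.
    return fn(x, *args), (x, args)


def _apply_bwd(fn, res, g):
    x, args = res
    _, vjp_fn = jax.vjp(fn, x, *args)
    # One cotangent per differentiable input: x, then each hoisted value.
    return vjp_fn(g)


_apply.defvjp(_apply_fwd, _apply_bwd)


def loss(w, x):
    # `w` only reaches `_apply` through the closure of the lambda.
    return apply(lambda v: jnp.sum(jnp.sin(w * v)), x)


w = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([0.5, -1.0, 2.0])
grad_w, grad_x = jax.grad(loss, argnums=(0, 1))(w, x)
# Mathematically: grad_w == x * cos(w * x) and grad_x == w * cos(w * x)
```

Saving (x, args) as residuals and recomputing the VJP in the backward pass keeps the residuals as plain arrays. A real solver would replace both passes with its own rules.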
Parameters:
fun (Callable) – Python callable to be converted. Must be a pure function.
example_args – Arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e., pytrees) thereof, used to determine the types of the formal arguments to fun. This type-specialized form of fun is the function that will be closure converted.
Returns:
A pair comprising (i) a Python callable, accepting the same arguments as fun followed by arguments corresponding to the values hoisted from its closure, and (ii) a list of values hoisted from the closure.
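Because the conversion is specialized to the types given in example_args, fun can be traced on placeholder values that have the right pytree structure, shapes and dtypes. The converted function can then be called on different values of the same type.

In this sketch (total, example and actual are hypothetical names), the converted function takes the same dict argument as fun, followed by the hoisted values.

```python
import jax
import jax.numpy as jnp


def total(scale, example, actual):
    # `fun` takes a dict pytree and closes over `scale`.
    def fun(d):
        return scale * d["a"] + d["b"]

    # Only the pytree structure, shapes and dtypes of `example` are used here.
    converted, hoisted = jax.closure_convert(fun, example)
    # Same formal arguments as `fun` (one dict), then the hoisted values.
    return jnp.sum(converted(actual, *hoisted))


example = {"a": jnp.zeros(2), "b": jnp.zeros(2)}
actual = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([3.0, 4.0])}

value = total(jnp.array(2.0), example, actual)                 # 2*(1+2) + (3+4) = 13
grad_scale = jax.grad(total)(jnp.array(2.0), example, actual)  # sum(actual["a"]) = 3
```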