jax.experimental.jet module
Jet is an experimental module for higher-order automatic differentiation that does not rely on repeated first-order automatic differentiation.
How? Through the propagation of truncated Taylor polynomials. Consider a function \(f = g \circ h\), some point \(x\) and some offset \(v\). First-order automatic differentiation (such as jax.jvp()) computes the pair \((f(x), \partial f(x)[v])\) from the pair \((h(x), \partial h(x)[v])\).
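The pair-propagation view of first-order automatic differentiation can be sketched with jax.jvp(). The functions h(x) = x**3 and g(y) = sin(y) below are illustrative choices, not part of the jet API. Pushing \((x, v)\) through \(h\) and then pushing the resulting pair through \(g\) should give the same pair as calling jax.jvp() on the composition directly.

```python
import jax
import jax.numpy as jnp

# Illustrative functions (assumed for this sketch): f = g ∘ h.
def h(x):
    return x ** 3

def g(y):
    return jnp.sin(y)

def f(x):
    return g(h(x))

x, v = 0.5, 1.0

# Push the pair (x, v) through h: yields (h(x), ∂h(x)[v]).
h0, h1 = jax.jvp(h, (x,), (v,))

# Push the pair (h(x), ∂h(x)[v]) through g: yields (f(x), ∂f(x)[v]).
f0, f1 = jax.jvp(g, (h0,), (h1,))

# The same pair computed directly on the composition f.
f0_direct, f1_direct = jax.jvp(f, (x,), (v,))
print(f0, f1)
print(f0_direct, f1_direct)
```

jet() generalizes this step from pairs to tuples of \(K + 1\) coefficients.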
jet() implements the higher-order analogue: given the tuple
\[(h_0, \dots, h_K) := (h(x), \partial h(x)[v], \partial^2 h(x)[v, v], \dots, \partial^K h(x)[v, \dots, v]),\]
which represents a \(K\)-th order Taylor approximation of \(h\) at \(x\), jet() returns a \(K\)-th order Taylor approximation of \(f\) at \(x\),
\[(f_0, \dots, f_K) := (f(x), \partial f(x)[v], \partial^2 f(x)[v, v], \dots, \partial^K f(x)[v, \dots, v]).\]
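The tuple above can be read as the derivatives of the curve \(t \mapsto h(x + tv)\) at \(t = 0\). Note that the coefficients \(h_k\) are plain derivatives; the factorials of the usual Taylor normalization appear only when the polynomial itself is written out. A worked instance, assuming \(h(z) = z^3\) and \(v = 1\):

```latex
h_k = \left.\frac{d^k}{dt^k}\, h(x + t v)\right|_{t=0},
\qquad
h(x + t v) = \sum_{k=0}^{K} \frac{h_k}{k!}\, t^k + O(t^{K+1}).

% Example: h(z) = z^3, v = 1
(h_0, h_1, h_2, h_3) = (x^3,\; 3x^2,\; 6x,\; 6),
\qquad
(x + t)^3 = x^3 + 3x^2\, t + \frac{6x}{2!}\, t^2 + \frac{6}{3!}\, t^3 .
```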
More specifically, jet() computes
\[f_0, (f_1, \dots, f_K) = \texttt{jet}(g, h_0, (h_1, \dots, h_K))\]
and can thus be used for higher-order automatic differentiation of \(f\). Details are explained in these notes.
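As a rough sanity check of this relation, the sketch below assumes \(h(x) = x \odot x\) (elementwise), \(g = \sin\), a vector point \(x\) and a non-unit direction \(v\). It writes down \(h_0, h_1, h_2\) for the curve \(t \mapsto (x + tv)^2\) by hand, calls jet() once on \(g\), and compares \(f_1, f_2\) with nested jax.jvp() calls on \(f = g \circ h\), i.e. the repeated first-order approach that jet is meant to avoid.

```python
import jax
import jax.numpy as jnp
from jax.experimental.jet import jet

# Assumed example: h(x) = x * x (elementwise), g = sin, so f(x) = sin(x * x).
x = jnp.array([0.5, 1.0])
v = jnp.array([1.0, -2.0])

# Taylor coefficients (plain derivatives) of t -> (x + t v)**2, up to K = 2.
h0 = x * x
h1 = 2 * x * v
h2 = 2 * v * v

# One jet call pushes the truncated polynomial (h0, h1, h2) through g.
f0, (f1, f2) = jet(jnp.sin, (h0,), ((h1, h2),))

# Reference: the same quantities by nesting first-order jvps on f.
f = lambda z: jnp.sin(z * z)
df = lambda z: jax.jvp(f, (z,), (v,))[1]     # z -> ∂f(z)[v]
ddf = lambda z: jax.jvp(df, (z,), (v,))[1]   # z -> ∂²f(z)[v, v]

print(f1, df(x))
print(f2, ddf(x))
```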
API
jax.experimental.jet.jet(fun, primals, series)
Taylor-mode higher-order automatic differentiation.
Parameters:
fun – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
primals – The primal values at which the Taylor approximation of fun should be evaluated. Should be either a tuple or a list of arguments, and its length should be equal to the number of positional parameters of fun.
series – Higher-order Taylor-series coefficients. Together, primals and series make up a truncated Taylor polynomial. Should be either a tuple or a list of tuples or lists, with one inner tuple or list per primal; the length of the inner tuples or lists dictates the degree of the truncated Taylor polynomial.
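To illustrate how primals and series line up for a function with more than one positional parameter, here is a sketch with a hypothetical two-argument fun(x, y) = x * y. series holds one inner tuple per primal, each of length 2 (degree 2), and the outputs can be checked against the Leibniz product rule. The input coefficients are arbitrary values chosen for illustration.

```python
import jax.numpy as jnp
from jax.experimental.jet import jet

def fun(x, y):
    # Two positional parameters, so primals needs exactly two entries.
    return x * y

x0, y0 = 2.0, 3.0
# Arbitrary derivatives of the input curves x(t) and y(t) at t = 0.
x1, x2 = 1.0, 0.5
y1, y2 = -1.0, 4.0

# One inner tuple per primal; each has length 2, so the degree is 2.
out0, (out1, out2) = jet(fun, (x0, y0), ((x1, x2), (y1, y2)))

# Leibniz rule for the product x(t) * y(t), using plain derivatives.
expected1 = x1 * y0 + x0 * y1                 # first derivative
expected2 = x2 * y0 + 2 * x1 * y1 + x0 * y2   # second derivative
print(out0, out1, expected1, out2, expected2)
```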
Returns:
A (primals_out, series_out) pair, where primals_out is fun(*primals), and together, primals_out and series_out are a truncated Taylor polynomial of \(f(h(\cdot))\). The primals_out value has the same Python tree structure as primals, and the series_out value the same Python tree structure as series.
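A small sketch of the return value for a function with a single array output, assuming jnp.exp as fun and the input curve \(x_0 + t\) (coefficients 1, 0, 0). primals_out should equal fun(*primals), and series_out should hold one array per degree. Because every derivative of exp along that curve is \(\exp(x_0)\), each entry can be checked directly.

```python
import jax.numpy as jnp
from jax.experimental.jet import jet

x0 = jnp.array([0.0, 1.0, 2.0])
# Degree-3 series for x(t) = x0 + t: first derivative 1, higher ones 0.
series = ((jnp.ones_like(x0), jnp.zeros_like(x0), jnp.zeros_like(x0)),)

primals_out, series_out = jet(jnp.exp, (x0,), series)

# primals_out is just fun(*primals).
print(jnp.allclose(primals_out, jnp.exp(x0)))
# One coefficient array per degree; each should equal exp(x0) here.
print(len(series_out), [bool(jnp.allclose(c, jnp.exp(x0))) for c in series_out])
```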
For example:
>>> import jax
>>> import jax.numpy as np
>>> from jax.experimental.jet import jet
Consider the function \(h(z) = z^3\), \(x = 0.5\), and the first few Taylor coefficients \(h_0=x^3\), \(h_1=3x^2\), and \(h_2=6x\). Let \(f(y) = \sin(y)\).
>>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
>>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)
jet() returns the Taylor coefficients of \(f(h(z)) = \sin(z^3)\) according to Faà di Bruno’s formula:
>>> f0, (f1, f2) = jet(f, (h0,), ((h1, h2),))
>>> print(f0, f(h0))
0.12467473 0.12467473
>>> print(f1, df(h0) * h1)
0.7441479 0.74414825
>>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2)
2.9064622 2.9064634
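Extending the example one order further gives a useful check. The sketch below assumes \(h_3 = 6\) (the third derivative of \(z^3\)) and compares the third coefficient returned by jet() with the third-order Faà di Bruno term \(f'''(h_0) h_1^3 + 3 f''(h_0) h_1 h_2 + f'(h_0) h_3\) and with three nested jax.grad() calls on \(\sin(z^3)\). It uses the jnp alias rather than np so that it runs on its own.

```python
import jax
import jax.numpy as jnp
from jax.experimental.jet import jet

x = 0.5
# Derivatives of h(z) = z**3 at x: h0 = x^3, h1 = 3x^2, h2 = 6x, h3 = 6.
h0, h1, h2, h3 = x ** 3, 3 * x ** 2, 6 * x, 6.0

f0, (f1, f2, f3) = jet(jnp.sin, (h0,), ((h1, h2, h3),))

# Third-order Faà di Bruno term for f = sin (f' = cos, f'' = -sin, f''' = -cos).
fdb3 = (-jnp.cos(h0) * h1 ** 3
        + 3 * (-jnp.sin(h0)) * h1 * h2
        + jnp.cos(h0) * h3)

# Cross-check with repeated first-order AD on the composite sin(z**3).
nested3 = jax.grad(jax.grad(jax.grad(lambda z: jnp.sin(z ** 3))))(x)
print(f3, fdb3, nested3)
```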