Set up a JAX-transformable function for a custom JVP rule definition.
This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a differentiation transformation (like jax.jvp() or jax.grad()) is applied. In that case, a custom user-supplied JVP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation.
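To show why replacing automatic differentiation can matter, here is a minimal sketch of a numerically fragile function, log(1 + e^x). The names `log1pexp_naive` and `log1pexp` are illustrative. Under default float32 precision, `exp(100.0)` overflows to `inf`, so differentiating through the implementation yields `0 * inf = nan`. A custom JVP rule written in an overflow-safe form avoids this.

```python
import jax
import jax.numpy as jnp


def log1pexp_naive(x):
    # Plain implementation: jax.grad traces through exp and log1p.
    return jnp.log1p(jnp.exp(x))


@jax.custom_jvp
def log1pexp(x):
    # Same primal computation as the naive version.
    return jnp.log1p(jnp.exp(x))


@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    ans = log1pexp(x)
    # d/dx log(1 + e^x) = 1 - 1 / (1 + e^x); this form stays finite
    # even when exp(x) overflows to inf.
    t_out = x_dot * (1.0 - 1.0 / (1.0 + jnp.exp(x)))
    return ans, t_out


naive_grad = jax.grad(log1pexp_naive)(100.0)  # autodiff through the implementation
stable_grad = jax.grad(log1pexp)(100.0)       # uses the custom rule instead
```

With 64-bit mode enabled, `exp(100.0)` does not overflow, so the naive gradient may be finite there. The custom rule is robust in either case.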
There are two instance methods available for defining the custom JVP rule: defjvp() for defining a single custom JVP rule for all the function’s inputs, and, for convenience, defjvps(), which wraps defjvp() and allows you to provide separate definitions for the partial derivatives of the function w.r.t. each of its arguments.
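The sketch below shows the per-argument style with defjvps(). It uses the same function as the defjvp() example in this section, sin(x) * y. Each function passed to defjvps() receives the tangent of its corresponding positional argument, then the primal output, then all primal inputs. The contributions are summed into the full tangent.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y


# One rule per positional argument: (tangent, primal_out, *primals).
f.defjvps(
    lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,  # partial w.r.t. x
    lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot,      # partial w.r.t. y
)

dfdx = jax.grad(f, argnums=0)(2.0, 3.0)  # cos(2) * 3
dfdy = jax.grad(f, argnums=1)(2.0, 3.0)  # sin(2)
```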
For example:
```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y


@f.defjvp
def f_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f(x, y)
    # Product rule: d(sin(x) * y) = cos(x) * y * dx + sin(x) * dy
    tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
    return primal_out, tangent_out
```
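As a usage sketch, the following block repeats the definition above so it runs on its own, then applies both differentiation modes. jax.jvp() calls the custom rule directly. For jax.grad(), JAX linearizes and transposes the tangent computation in the rule, so reverse mode works without a separate VJP definition. This relies on the tangent output being linear in the input tangents, as it is here.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y


@f.defjvp
def f_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f(x, y)
    tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
    return primal_out, tangent_out


# Forward mode: perturb x only, so the tangent is cos(x) * y.
y_out, t_out = jax.jvp(f, (2.0, 3.0), (1.0, 0.0))

# Reverse mode: gradient w.r.t. the first argument, derived from f_jvp.
g = jax.grad(f)(2.0, 3.0)
```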
For a more detailed introduction, see the tutorial.
Methods
__init__(fun[, nondiff_argnums, ...])
defjvp(jvp[, symbolic_zeros]): Define a custom JVP rule for the function represented by this instance.
defjvps(*jvps): Convenience wrapper for defining JVPs for each argument separately.
Attributes
jvp
symbolic_zeros
fun
nondiff_argnums
nondiff_argnames
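The nondiff_argnums attribute lists positional arguments that are not differentiated, such as a Python integer or a function. The sketch below uses a hypothetical `power(x, n)` with an integer exponent `n` marked non-differentiable. With custom_jvp, the non-differentiable arguments are passed to the JVP rule first, followed by the primals and tangents of the remaining arguments.

```python
from functools import partial

import jax


# Hypothetical example: n is a static integer exponent, excluded from
# differentiation via nondiff_argnums.
@partial(jax.custom_jvp, nondiff_argnums=(1,))
def power(x, n):
    return x ** n


@power.defjvp
def power_jvp(n, primals, tangents):
    # Non-differentiable args come first; primals/tangents cover only x.
    x, = primals
    x_dot, = tangents
    return power(x, n), n * x ** (n - 1) * x_dot


slope = jax.grad(power)(2.0, 3)  # 3 * 2.0**2
```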