Define a custom JVP rule for the function represented by this instance.
jvp (Callable[..., tuple[ReturnValue, ReturnValue]]) – a Python callable representing the custom JVP rule. When there are no nondiff_argnums, the jvp function should accept two arguments, where the first is a tuple of primal inputs and the second is a tuple of tangent inputs. The lengths of both tuples are equal to the number of parameters of the custom_jvp function. The jvp function should produce as output a pair where the first element is the primal output and the second element is the tangent output. Elements of the input and output tuples may be arrays or any nested tuples/lists/dicts thereof.
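A minimal sketch of this calling convention when an input is a nested container, assuming no nondiff_argnums. The function `affine` and its dict keys `w` and `b` are hypothetical, chosen only for illustration: each entry of `tangents` mirrors the structure of the matching entry of `primals`, and the tangent output mirrors the primal output.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: an affine map whose parameters live in a dict.
@jax.custom_jvp
def affine(params, x):
    return jnp.dot(params["w"], x) + params["b"]

@affine.defjvp
def affine_jvp(primals, tangents):
    # One entry per parameter of `affine`; each tangent has the same
    # nested (dict/array) structure as the corresponding primal.
    params, x = primals
    params_dot, x_dot = tangents
    primal_out = affine(params, x)
    # Product rule on w @ x, plus the tangent of the bias term.
    tangent_out = (jnp.dot(params_dot["w"], x)
                   + jnp.dot(params["w"], x_dot)
                   + params_dot["b"])
    return primal_out, tangent_out

params = {"w": jnp.array([0.5, -1.0, 2.0]), "b": jnp.float32(0.25)}
x = jnp.array([1.0, 2.0, 3.0])

value = affine(params, x)                    # 0.5 - 2.0 + 6.0 + 0.25
param_grads = jax.grad(affine)(params, x)    # same dict structure as params
```

Since the gradient with respect to `params` comes back as a dict with the same keys, the rule never needs to flatten or unflatten the container itself.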
symbolic_zeros (bool) – whether the rule should be passed objects representing static symbolic zeros in its tangent argument for unperturbed inputs; otherwise, only standard JAX types (e.g. array-likes) are passed. Setting this option to True allows a JVP rule to detect whether certain inputs are not involved in differentiation, at the cost of needing special handling for these objects (which, e.g., can't be passed into jax.numpy functions). Default False.
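A sketch of a rule registered with symbolic_zeros=True. It assumes that `jax.custom_derivatives.SymbolicZero` is the type of the objects passed for unperturbed inputs (the case in recent JAX releases; worth checking against your version). The function `g` and the `seen_symbolic` list are illustrative only: the list records which tangents arrived as symbolic zeros, and the rule skips those terms rather than passing them to jax.numpy. Because symbolic_zeros is a keyword argument, the rule is registered with a direct call instead of the bare decorator.

```python
import jax
import jax.numpy as jnp
from jax.custom_derivatives import SymbolicZero

# Illustrative only: records (x_dot is zero, y_dot is zero) each time the rule is traced.
seen_symbolic = []

@jax.custom_jvp
def g(x, y):
    return jnp.sin(x) * y

def g_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = g(x, y)
    seen_symbolic.append(
        (isinstance(x_dot, SymbolicZero), isinstance(y_dot, SymbolicZero)))
    # SymbolicZero objects can't be passed into jax.numpy, so drop their terms.
    terms = []
    if not isinstance(x_dot, SymbolicZero):
        terms.append(jnp.cos(x) * x_dot * y)
    if not isinstance(y_dot, SymbolicZero):
        terms.append(jnp.sin(x) * y_dot)
    if not terms:
        # Defensive fallback: no perturbed inputs, so the tangent is zero.
        return primal_out, jnp.zeros_like(primal_out)
    tangent_out = terms[0]
    for term in terms[1:]:
        tangent_out = tangent_out + term
    return primal_out, tangent_out

# Call defjvp directly to pass the keyword argument.
g.defjvp(g_jvp, symbolic_zeros=True)

dg_dx = jax.grad(g, argnums=0)(1.0, 2.0)   # y is unperturbed here
dg_dy = jax.grad(g, argnums=1)(1.0, 2.0)   # x is unperturbed here
```

The fallback branch only guards against building an empty sum; the interesting cases are the two `jax.grad` calls, where exactly one tangent is a symbolic zero.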
Returns: jvp, so that defjvp can be used as a decorator.
Return type: Callable[…, tuple[ReturnValue, ReturnValue]]
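Because defjvp hands back the rule it was given, the rule stays an ordinary Python function after registration. The sketch below (with hypothetical names `h` and `h_jvp`) relies on that to call the rule directly and compare it with what jax.jvp computes through the custom_jvp function, a pattern that may be handy in unit tests.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def h(x):
    return jnp.exp(x)

def h_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = h(x)
    # d/dx exp(x) = exp(x), so the primal output is reused in the tangent.
    return y, y * x_dot

# defjvp returns its argument, which is what makes the @h.defjvp form work;
# the name h_jvp therefore still refers to the rule itself.
registered = h.defjvp(h_jvp)

# Call the rule directly, then go through the custom_jvp function.
direct_out, direct_tangent = h_jvp((1.0,), (2.0,))
jvp_out, jvp_tangent = jax.jvp(h, (1.0,), (2.0,))
```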
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> @jax.custom_jvp
... def f(x, y):
...   return jnp.sin(x) * y
...
>>> @f.defjvp
... def f_jvp(primals, tangents):
...   x, y = primals
...   x_dot, y_dot = tangents
...   primal_out = f(x, y)
...   tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
...   return primal_out, tangent_out

>>> x = jnp.float32(1.0)
>>> y = jnp.float32(2.0)
>>> with jnp.printoptions(precision=2):
...   print(jax.value_and_grad(f)(x, y))
(Array(1.68, dtype=float32), Array(1.08, dtype=float32))
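The example above exercises the rule through reverse mode. As a rough illustration, the same `f` and `f_jvp` can also be driven in forward mode with jax.jvp, where the tangents you seed reach the rule's `tangents` argument; seeding one-hot tangents recovers each partial derivative, which should agree with jax.grad.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f(x, y)
    tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
    return primal_out, tangent_out

x = jnp.float32(1.0)
y = jnp.float32(2.0)

# Tangent (1, 0) selects the partial derivative with respect to x ...
_, df_dx = jax.jvp(f, (x, y), (jnp.float32(1.0), jnp.float32(0.0)))
# ... and tangent (0, 1) selects the partial derivative with respect to y.
_, df_dy = jax.jvp(f, (x, y), (jnp.float32(0.0), jnp.float32(1.0)))
```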