Convenience wrapper for defining JVPs for each argument separately.
This convenience wrapper cannot be used together with nondiff_argnums.
Parameters:
*jvps (Callable[..., ReturnValue] | None) – a sequence of functions, one for each positional argument of the custom_jvp function. Each function takes as arguments the tangent value for the corresponding primal input, the primal output, and the primal inputs. See the example below.
Returns: None.
Return type: None
Examples
>>> @jax.custom_jvp
... def f(x, y):
...   return jnp.sin(x) * y
...
>>> f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
...           lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)

>>> x = jnp.float32(1.0)
>>> y = jnp.float32(2.0)
>>> with jnp.printoptions(precision=2):
...   print(jax.value_and_grad(f)(x, y))
(Array(1.68, dtype=float32), Array(1.08, dtype=float32))
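
To relate the per-argument form to a single combined rule, the sketch below defines the same function twice: once with defjvps, as in the example above, and once with one rule registered through defjvp. It assumes the per-argument tangent contributions are summed to give the output tangent, which is what the product rule requires for this function. The names g and g_jvp are hypothetical and exist only for this comparison.

```python
import jax
import jax.numpy as jnp

# Per-argument rules, as in the example above.
@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y

f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
          lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)

# Hypothetical twin of f with a single combined rule (illustration only).
@jax.custom_jvp
def g(x, y):
    return jnp.sin(x) * y

@g.defjvp
def g_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = g(x, y)
    # Assumed equivalent: the sum of the two per-argument contributions.
    tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
    return primal_out, tangent_out

x = jnp.float32(1.0)
y = jnp.float32(2.0)

# Gradients with respect to both arguments, for each formulation.
grads_f = jax.grad(f, argnums=(0, 1))(x, y)
grads_g = jax.grad(g, argnums=(0, 1))(x, y)
```

If the assumption holds, grads_f and grads_g agree for both arguments. The per-argument form is shorter when each contribution is easy to write on its own, while a single defjvp rule may be preferable when the contributions share intermediate computations.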