jax.custom_jvp.defjvps — JAX documentation

custom_jvp.defjvps(*jvps)

Convenience wrapper for defining JVPs for each argument separately, instead of writing a single JVP rule with defjvp.

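This convenience wrapper cannot be used together with nondiff_argnums: calling it on a custom_jvp function that declares nondiff_argnums raises a TypeError. Use defjvp directly in that case.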
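A minimal sketch of the restriction and its workaround, assuming a hypothetical function power_sin whose first argument (an integer power) is marked non-differentiable. The defjvps call is rejected, while defjvp accepts the rule; with defjvp, the non-differentiable arguments are passed first, followed by the primals and tangents of the remaining arguments.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical example: n is a non-differentiable integer power.
@partial(jax.custom_jvp, nondiff_argnums=(0,))
def power_sin(n, x):
    return jnp.sin(x) ** n


# defjvps refuses functions that declare nondiff_argnums.
try:
    power_sin.defjvps(lambda x_dot, primal_out, n, x: x_dot)
    defjvps_rejected = False
except TypeError:
    defjvps_rejected = True


# defjvp works: non-differentiable args come first, then (primals, tangents).
@power_sin.defjvp
def power_sin_jvp(n, primals, tangents):
    (x,), (x_dot,) = primals, tangents
    primal_out = power_sin(n, x)
    # d/dx sin(x)**n = n * sin(x)**(n - 1) * cos(x)
    tangent_out = n * jnp.sin(x) ** (n - 1) * jnp.cos(x) * x_dot
    return primal_out, tangent_out


# Differentiate with respect to x only (argument index 1).
grad_at_1 = jax.grad(power_sin, argnums=1)(3, 1.0)
```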

Parameters:

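*jvps (Callable[..., ReturnValue] | None) – a sequence of functions, one for each positional argument of the custom_jvp function. Each function is called with the tangent value for the corresponding primal input, followed by the primal output and then all of the primal inputs, and returns that argument's contribution to the output tangent. The contributions from all arguments are summed to form the output tangent. An entry may be None, in which case that argument contributes a zero tangent. See the example below.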
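The sketch below illustrates the calling convention with two hypothetical functions. my_exp shows why the primal output is passed in: the derivative of exp equals its output, so the rule can reuse primal_out instead of recomputing jnp.exp(x). scaled_sin passes None for its second argument, so y contributes nothing to the output tangent and its gradient is zero.

```python
import jax
import jax.numpy as jnp


# Reusing primal_out: d/dx exp(x) == exp(x), which is already computed.
@jax.custom_jvp
def my_exp(x):
    return jnp.exp(x)


my_exp.defjvps(lambda x_dot, primal_out, x: primal_out * x_dot)


# A None entry: y gets no rule, so its tangent contribution is zero.
@jax.custom_jvp
def scaled_sin(x, y):
    return jnp.sin(x) * y


scaled_sin.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
                   None)

exp_out, exp_tangent = jax.jvp(my_exp, (1.0,), (1.0,))

# Only the x contribution, cos(x) * x_dot * y, appears in the output tangent.
_, tangent_out = jax.jvp(scaled_sin, (1.0, 2.0), (1.0, 1.0))

# The gradient with respect to y is zero because its rule is None.
grad_y = jax.grad(scaled_sin, argnums=1)(1.0, 2.0)
```

Note that the None rule makes the custom derivative deliberately differ from the true derivative of sin(x) * y with respect to y; it is shown only to illustrate the semantics.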

Returns:

None.

Return type:

None

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> @jax.custom_jvp
... def f(x, y):
...   return jnp.sin(x) * y
...
>>> f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
...           lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)
>>> x = jnp.float32(1.0)
>>> y = jnp.float32(2.0)
>>> with jnp.printoptions(precision=2):
...   print(jax.value_and_grad(f)(x, y))
(Array(1.68, dtype=float32), Array(1.08, dtype=float32))
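Because each per-argument rule returns one contribution and the contributions are summed, the example above should behave like a single rule written with defjvp that adds the two terms itself. The sketch below, using the hypothetical names f_per_arg and f_single, defines the same function both ways and compares the results of jax.jvp on arbitrary tangents.

```python
import jax
import jax.numpy as jnp


# Per-argument rules registered with defjvps (same as the example above).
@jax.custom_jvp
def f_per_arg(x, y):
    return jnp.sin(x) * y


f_per_arg.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
                  lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)


# The same derivative written as one rule with defjvp: the two
# per-argument contributions are added explicitly.
@jax.custom_jvp
def f_single(x, y):
    return jnp.sin(x) * y


@f_single.defjvp
def f_single_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f_single(x, y)
    tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
    return primal_out, tangent_out


primals = (jnp.float32(1.0), jnp.float32(2.0))
tangents = (jnp.float32(0.5), jnp.float32(-1.0))
out_a, tan_a = jax.jvp(f_per_arg, primals, tangents)
out_b, tan_b = jax.jvp(f_single, primals, tangents)
```

defjvps is convenient when each argument's contribution has a simple closed form; defjvp is more flexible when the contributions share intermediate values or when nondiff_argnums is needed.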
