Creates a function that evaluates the gradient of fun.
Parameters:
fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. Argument arrays in the positions specified by argnums must be of inexact (i.e., floating-point or complex) type. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
argnums (int | Sequence[int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.
allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
reduce_axes (Sequence[AxisName])
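The parameters above can be exercised with a few small functions. This is a minimal sketch assuming a recent JAX release with default 32-bit types; the functions `loss`, `loss_with_aux`, `scale`, and `vector_output` are hypothetical and exist only for illustration.

```python
import jax
import jax.numpy as jnp


def loss(w, b):
    # Scalar output (a shape () array), as jax.grad requires.
    return jnp.sum((2.0 * w + b) ** 2)


w = jnp.array([1.0, 2.0])
b = 0.5

# argnums=0 (the default): gradient with respect to w only, same shape as w.
dw = jax.grad(loss)(w, b)

# argnums=(0, 1): a tuple of gradients, one per selected argument.
dw2, db = jax.grad(loss, argnums=(0, 1))(w, b)


def loss_with_aux(w, b):
    pred = 2.0 * w + b
    # has_aux=True: return (value to differentiate, auxiliary data).
    return jnp.sum(pred ** 2), {"pred": pred}


# The auxiliary data is passed through untouched, not differentiated.
dw_aux, aux = jax.grad(loss_with_aux, has_aux=True)(w, b)

# holomorphic=True: complex input and output; the result is the complex
# derivative, here d(z**2)/dz = 2 * z.
z = jnp.array(1.0 + 2.0j)
dz = jax.grad(lambda z: z ** 2, holomorphic=True)(z)


def scale(x, n):
    return x * n


# allow_int=True: differentiating with respect to the integer argument n
# is permitted, and its gradient has the trivial dtype float0.
dn = jax.grad(scale, argnums=1, allow_int=True)(2.0, jnp.array(3))


def vector_output(w):
    # Output shape is (1,), which is not a scalar.
    return jnp.sum(w, keepdims=True)


# jax.grad rejects non-scalar outputs with a TypeError.
try:
    jax.grad(vector_output)(w)
    raised = False
except TypeError:
    raised = True
```

With the default allow_int=False, the call on scale with argnums=1 would instead raise a TypeError, since argument arrays being differentiated must be of inexact type.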
Returns:
A function with the same arguments as fun, that evaluates the gradient of fun. If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a pair of (gradient, auxiliary_data) is returned.
Return type:
Callable
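When a differentiated argument is a standard Python container, the gradient mirrors its structure: same keys, and each leaf has the shape and dtype of the matching input leaf. A minimal sketch; `params` and `model_loss` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters stored in a dict of arrays (a pytree).
params = {"w": jnp.array([1.0, -1.0]), "b": jnp.array(0.5)}
x = jnp.array([3.0, 4.0])


def model_loss(params, x):
    # Linear prediction, squared (scalar output).
    pred = jnp.dot(params["w"], x) + params["b"]
    return pred ** 2


# argnums defaults to 0, so the gradient is taken with respect to params.
grads = jax.grad(model_loss)(params, x)

# grads is a dict with the same keys as params; each leaf matches
# the shape and dtype of the corresponding parameter.
for key in params:
    print(key, grads[key].shape, grads[key].dtype)
```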
For example:
>>> import jax
>>>
>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.961043
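The printed value can be checked against the closed form d/dx tanh(x) = 1 - tanh(x)^2, which at x = 0.2 gives about 0.961043. Because grad_tanh expects a scalar input, the sketch below assumes jax.vmap is used to apply it elementwise over an array of points; `tanh_prime` is a hypothetical helper for the comparison.

```python
import jax
import jax.numpy as jnp

grad_tanh = jax.grad(jnp.tanh)


def tanh_prime(x):
    # Closed-form derivative of tanh.
    return 1.0 - jnp.tanh(x) ** 2


xs = jnp.linspace(-2.0, 2.0, 5)

# grad_tanh only accepts scalars, so vmap maps it over xs.
autodiff = jax.vmap(grad_tanh)(xs)
closed_form = tanh_prime(xs)

# Largest absolute difference between the two; should be near float32 precision.
print(jnp.max(jnp.abs(autodiff - closed_form)))
```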