jax package

Subpackages

Configuration

Just-in-time compilation (jit)

jit(fun, /, *[, in_shardings, ...]): Sets up fun for just-in-time compilation with XLA.
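A minimal sketch of jit, assuming a standard jax installation. The function selu and its constants are illustrative only and are not part of the JAX API.

```python
import jax
import jax.numpy as jnp

# Illustrative elementwise function (hypothetical example, not a JAX API).
def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# jax.jit traces selu for the given input shape/dtype and compiles it with XLA.
selu_jit = jax.jit(selu)

x = jnp.arange(-2.0, 3.0)
y = selu_jit(x)        # first call traces and compiles
y_again = selu_jit(x)  # same shape/dtype: the cached executable is reused
```

Calling selu_jit with an input of a different shape or dtype triggers a fresh trace and compile.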
disable_jit([disable]): Context manager that disables jit() behavior under its dynamic context.
ensure_compile_time_eval(): Context manager to ensure evaluation at trace/compile time (or error).
make_jaxpr([axis_env, return_shape, ...]): Create a function that returns the jaxpr of fun given example args.
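To see what jit actually traces, make_jaxpr can print the intermediate representation. A small sketch; the function f is illustrative, and the exact printed text varies across JAX versions, so only the presence of primitive names is relied on here.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0  # illustrative function

# make_jaxpr(f) returns a function; calling it with example args traces f
# and returns a ClosedJaxpr describing the computation.
closed = jax.make_jaxpr(f)(1.0)

# Printing shows the input binders, the primitive equations, and the outputs.
text = str(closed)
print(text)
```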
eval_shape(fun, *args, **kwargs): Compute the shape/dtype of fun without any FLOPs.
ShapeDtypeStruct(shape, dtype, *[, ...]): A container for the shape, dtype, and other static attributes of an array.
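eval_shape pairs naturally with ShapeDtypeStruct when only output metadata is needed. A minimal sketch; the matrix sizes are arbitrary, and no array data is ever allocated for a or b.

```python
import jax
import jax.numpy as jnp

def f(a, b):
    return jnp.dot(a, b)

# ShapeDtypeStruct stands in for an array: shape and dtype only, no data.
a = jax.ShapeDtypeStruct((1024, 512), jnp.float32)
b = jax.ShapeDtypeStruct((512, 256), jnp.float32)

# f is traced abstractly; no FLOPs are performed.
out = jax.eval_shape(f, a, b)
print(out.shape, out.dtype)
```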
device_put(x[, device, src, donate, may_alias]): Transfers x to device.
device_get(x): Transfer x to host.
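A minimal host/device round trip, assuming at least one device is visible (on a CPU-only install, jax.devices()[0] is the CPU).

```python
import jax
import numpy as np

host = np.arange(6, dtype=np.float32).reshape(2, 3)

# Move the NumPy array onto the first available device.
dev = jax.device_put(host, jax.devices()[0])

# device_get copies back to host memory and returns a NumPy array.
back = jax.device_get(dev)
```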
default_backend(): Returns the platform name of the default XLA backend.
named_call(fun, *[, name]): Adds a user-specified name to a function when staging out JAX computations.
named_scope(name): A context manager that adds a user-specified name to the JAX name stack.
block_until_ready(x): Tries to call a block_until_ready method on pytree leaves.
copy_to_host_async(x): Tries to call a copy_to_host_async method on pytree leaves.
make_mesh(axis_shapes, axis_names, *[, ...]): Creates an efficient mesh with the shape and axis names specified.

Automatic differentiation

grad(fun[, argnums, has_aux, holomorphic, ...]): Creates a function that evaluates the gradient of fun.
value_and_grad(fun[, argnums, has_aux, ...]): Create a function that evaluates both fun and the gradient of fun.
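grad and value_and_grad both differentiate with respect to the first positional argument by default (argnums=0). A sketch with an illustrative least-squares loss; the loss function and data are made up for the example.

```python
import jax
import jax.numpy as jnp

# Illustrative loss; w is the parameter being differentiated.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

w = jnp.array([1.0, 2.0])
x = jnp.eye(2)
y = jnp.zeros(2)

g = jax.grad(loss)(w, x, y)                  # d loss / d w, same shape as w
val, g2 = jax.value_and_grad(loss)(w, x, y)  # loss value plus the same gradient
```

With x the identity, the prediction equals w, so the loss is (1 + 4) / 2 = 2.5 and the gradient is w itself.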
jacobian(fun[, argnums, has_aux, ...]): Alias of jax.jacrev().
jacfwd(fun[, argnums, has_aux, holomorphic]): Jacobian of fun evaluated column-by-column using forward-mode AD.
jacrev(fun[, argnums, has_aux, holomorphic, ...]): Jacobian of fun evaluated row-by-row using reverse-mode AD.
hessian(fun[, argnums, has_aux, holomorphic]): Hessian of fun as a dense array.
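A sketch comparing the Jacobian and Hessian helpers on small illustrative functions. jacfwd and jacrev should agree up to floating-point error; they differ only in how the matrix is assembled.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative map from R^2 to R^2.
    return jnp.stack([x[0] ** 2 * x[1], jnp.sin(x[1])])

x = jnp.array([1.0, 2.0])
J_fwd = jax.jacfwd(f)(x)  # one JVP per input, i.e. column by column
J_rev = jax.jacrev(f)(x)  # one VJP per output, i.e. row by row

def g(x):
    return jnp.sum(x ** 3)  # illustrative scalar function

H = jax.hessian(g)(x)  # dense (2, 2) matrix, equal to diag(6 * x) here
```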
jvp(fun, primals, tangents[, has_aux]): Computes a (forward-mode) Jacobian-vector product of fun.
linearize(fun, *primals[, has_aux]): Produces a linear approximation to fun using jvp() and partial eval.
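jvp evaluates the function and one directional derivative in a single pass, while linearize returns a reusable function for the same JVP at a fixed point. An illustrative sketch; f is a made-up elementwise function.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x  # illustrative elementwise function

x = jnp.array([0.5, 1.0])
v = jnp.array([1.0, 0.0])  # tangent direction

# jvp takes tuples of primals and tangents and returns (f(x), J @ v).
y, y_dot = jax.jvp(f, (x,), (v,))

# linearize returns f(x) and a function computing J @ v for any v.
y2, f_jvp = jax.linearize(f, x)
```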
linear_transpose(fun, *primals[, reduce_axes]): Transpose a function that is promised to be linear.
vjp(fun, *primals[, has_aux, reduce_axes]): Compute a (reverse-mode) vector-Jacobian product of fun.
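vjp is the reverse-mode counterpart: it returns the primal output and a function that pulls an output cotangent back to the inputs. A sketch with an illustrative elementwise f, for which the VJP with a vector of ones equals the elementwise derivative.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x  # illustrative elementwise function

x = jnp.array([0.5, 1.0])
y, f_vjp = jax.vjp(f, x)  # primal output and the pullback function

# The pullback takes an output cotangent and returns a tuple with one
# cotangent per primal argument (here just x).
(x_bar,) = f_vjp(jnp.ones_like(y))
```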
custom_gradient(fun): Convenience function for defining custom VJP rules (aka custom gradients).
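A minimal custom_gradient sketch. The function square is illustrative, and its hand-written rule happens to match the automatic one, so the example only demonstrates the calling convention: the forward output plus a function returning one cotangent per input, as a tuple.

```python
import jax

@jax.custom_gradient
def square(x):
    # Return the primal output together with a function that maps the
    # output cotangent g to a tuple holding one cotangent per input.
    return x ** 2, lambda g: (g * 2.0 * x,)

val = square(3.0)
g = jax.grad(square)(3.0)
```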
closure_convert(fun, *example_args): Closure conversion utility, for use with higher-order custom derivatives.
checkpoint(fun, *[, prevent_cse, policy, ...]): Make fun recompute internal linearization points when differentiated.
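checkpoint (also exported as jax.remat) changes the memory/compute trade-off of differentiation, not its results. A sketch, with block as an illustrative function, checking that gradients with and without rematerialization agree.

```python
import jax
import jax.numpy as jnp

def block(x):
    return jnp.tanh(jnp.sin(x) * 3.0)  # illustrative function

# The checkpointed version saves fewer intermediates on the forward pass
# and recomputes them during the backward pass.
block_remat = jax.checkpoint(block)

x = jnp.linspace(0.0, 1.0, 4)
g_plain = jax.grad(lambda x: jnp.sum(block(x)))(x)
g_remat = jax.grad(lambda x: jnp.sum(block_remat(x)))(x)
```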

Vectorization (vmap)

vmap(fun[, in_axes, out_axes, axis_name, ...]): Vectorizing map.
numpy.vectorize(pyfunc, *[, excluded, signature]): Define a vectorized function with broadcasting.
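A sketch of vmap and numpy.vectorize batching the same per-example function. Here dot is an illustrative wrapper around jnp.dot written for single vectors.

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    return jnp.dot(a, b)  # written for two 1-D vectors

A = jnp.arange(6.0).reshape(3, 2)  # [[0, 1], [2, 3], [4, 5]]
B = jnp.ones((3, 2))

# Map over axis 0 of both arguments.
batched = jax.vmap(dot)(A, B)

# in_axes=None broadcasts the second argument to every row of A.
with_shared = jax.vmap(dot, in_axes=(0, None))(A, jnp.array([1.0, -1.0]))

# numpy.vectorize expresses the same batching via a gufunc-style signature.
vec_dot = jnp.vectorize(dot, signature="(n),(n)->()")
```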

Parallelization (pmap)

shard_map([f, axis_names, in_specs, mesh, ...]): Map a function over shards of data using a mesh of devices.
pmap(fun[, axis_name, in_axes, out_axes, ...]): Parallel map with support for collective operations.
devices([backend]): Returns a list of all devices for a given backend.
local_devices([process_index, backend, host_id]): Like jax.devices(), but only returns devices local to a given process.
process_index([backend]): Returns the integer process index of this process.
device_count([backend]): Returns the total number of devices.
local_device_count([backend]): Returns the number of devices addressable by this process.
process_count([backend]): Returns the number of JAX processes associated with the backend.
process_indices([backend]): Returns the list of all JAX process indices associated with the backend.
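A sketch of the device queries above together with pmap and a psum collective. It assumes nothing about the hardware: the leading axis is sized by jax.local_device_count(), so on a single-CPU machine it runs with n = 1. The axis name "i" is arbitrary, and pmap is used here only because it needs no explicit mesh.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()  # devices addressable by this process
print(jax.process_index(), jax.process_count(), jax.device_count())

# pmap maps over the leading axis, which must match the number of devices used.
xs = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)

# psum is a collective: every device receives the sum over the "i" axis.
total = jax.pmap(lambda x: jax.lax.psum(x, axis_name="i"), axis_name="i")(xs)
```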

Customization

custom_jvp

custom_jvp(fun[, nondiff_argnums]): Set up a JAX-transformable function for a custom JVP rule definition.
custom_jvp.defjvp(jvp[, symbolic_zeros]): Define a custom JVP rule for the function represented by this instance.
custom_jvp.defjvps(*jvps): Convenience wrapper for defining JVPs for each argument separately.
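A custom_jvp sketch for log(1 + exp(x)), a common numerical-stability case. The rule supplies the derivative sigmoid(x) directly, so the gradient stays finite for large x, although the primal value itself still overflows to inf in float32. The function names are illustrative.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = log1pexp(x)
    # d/dx log(1 + e^x) = sigmoid(x), which jax.nn.sigmoid evaluates stably.
    return y, jax.nn.sigmoid(x) * x_dot

g100 = jax.grad(log1pexp)(100.0)
g0 = jax.grad(log1pexp)(0.0)
```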

custom_vjp

custom_vjp(fun[, nondiff_argnums]): Set up a JAX-transformable function for a custom VJP rule definition.
custom_vjp.defvjp(fwd, bwd[, ...]): Define a custom VJP rule for the function represented by this instance.
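A custom_vjp sketch that clips the incoming gradient while leaving the forward pass unchanged. The names clip_gradient, clip_fwd, and clip_bwd are illustrative, and the clipping range [-1, 1] is an arbitrary choice.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_gradient(x):
    return x  # identity on the forward pass

def clip_fwd(x):
    return x, None  # (primal output, residuals); no residuals are needed

def clip_bwd(res, g):
    # Return one cotangent per primal input, as a tuple.
    return (jnp.clip(g, -1.0, 1.0),)

clip_gradient.defvjp(clip_fwd, clip_bwd)

# The outer factor 10.0 produces a cotangent of 10.0, which bwd clips to 1.0.
g = jax.grad(lambda x: 10.0 * clip_gradient(x))(3.0)
```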

custom_batching

jax.Array (jax.Array)

Array properties and methods
Array.addressable_shards: List of addressable shards.
Array.all([axis, out, keepdims, where]): Test whether all array elements along a given axis evaluate to True.
Array.any([axis, out, keepdims, where]): Test whether any array elements along a given axis evaluate to True.
Array.argmax([axis, out, keepdims]): Return the index of the maximum value.
Array.argmin([axis, out, keepdims]): Return the index of the minimum value.
Array.argpartition(kth[, axis]): Return the indices that partially sort the array.
Array.argsort([axis, kind, order, stable, ...]): Return the indices that sort the array.
Array.astype(dtype[, copy, device]): Copy the array and cast to a specified dtype.
Array.at: Helper property for index update functionality.
Array.choose(choices[, out, mode]): Construct an array choosing from elements of multiple arrays.
Array.clip([min, max]): Return an array whose values are limited to a specified range.
Array.compress(condition[, axis, out, size, ...]): Return selected slices of this array along a given axis.
Array.committed: Whether the array is committed or not.
Array.conj(): Return the complex conjugate of the array.
Array.conjugate(): Return the complex conjugate of the array.
Array.copy(): Return a copy of the array.
Array.copy_to_host_async(): Copies an Array to the host asynchronously.
Array.cumprod([axis, dtype, out]): Return the cumulative product of the array.
Array.cumsum([axis, dtype, out]): Return the cumulative sum of the array.
Array.device: Array API-compatible device attribute.
Array.diagonal([offset, axis1, axis2]): Return the specified diagonal from the array.
Array.dot(b, *[, precision, ...]): Compute the dot product of two arrays.
Array.dtype: The data type (numpy.dtype) of the array.
Array.flat: Use flatten() instead.
Array.flatten([order, out_sharding]): Flatten array into a 1-dimensional shape.
Array.global_shards: List of global shards.
Array.imag: Return the imaginary part of the array.
Array.is_fully_addressable: Is this Array fully addressable?
Array.is_fully_replicated: Is this Array fully replicated?
Array.item(*args): Copy an element of an array to a standard Python scalar and return it.
Array.itemsize: Length of one array element in bytes.
Array.max([axis, out, keepdims, initial, where]): Return the maximum of array elements along a given axis.
Array.mean([axis, dtype, out, keepdims, where]): Return the mean of array elements along a given axis.
Array.min([axis, out, keepdims, initial, where]): Return the minimum of array elements along a given axis.
Array.nbytes: Total bytes consumed by the elements of the array.
Array.ndim: The number of dimensions in the array.
Array.nonzero(*[, fill_value, size]): Return indices of nonzero elements of an array.
Array.prod([axis, dtype, out, keepdims, ...]): Return the product of the array elements over a given axis.
Array.ptp([axis, out, keepdims]): Return the peak-to-peak range along a given axis.
Array.ravel([order, out_sharding]): Flatten array into a 1-dimensional shape.
Array.real: Return the real part of the array.
Array.repeat(repeats[, axis, ...]): Construct an array from repeated elements.
Array.reshape(*args[, order, out_sharding]): Returns an array containing the same data with a new shape.
Array.round([decimals, out]): Round array elements to a given decimal.
Array.searchsorted(v[, side, sorter, method]): Perform a binary search within a sorted array.
Array.shape: The shape of the array.
Array.sharding: The sharding for the array.
Array.size: The total number of elements in the array.
Array.sort([axis, kind, order, stable, ...]): Return a sorted copy of an array.
Array.squeeze([axis]): Remove one or more length-1 axes from the array.
Array.std([axis, dtype, out, ddof, ...]): Compute the standard deviation along a given axis.
Array.sum([axis, dtype, out, keepdims, ...]): Sum of the elements of the array over a given axis.
Array.swapaxes(axis1, axis2): Swap two axes of an array.
Array.take(indices[, axis, out, mode, ...]): Take elements from an array.
Array.to_device(device, *[, stream]): Return a copy of the array on the specified device.
Array.trace([offset, axis1, axis2, dtype, out]): Return the sum along the diagonal.
Array.transpose(*args): Returns a copy of the array with axes transposed.
Array.var([axis, dtype, out, ddof, ...]): Compute the variance along a given axis.
Array.view([dtype, type]): Return a bitwise copy of the array, viewed as a new dtype.
Array.T: Compute the all-axis array transpose.
Array.mT: Compute the (batched) matrix transpose.
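A few of the jax.Array methods and properties listed above in use. The main difference from NumPy shown here is that jax.Array is immutable, so at[...] returns an updated copy; the values are illustrative.

```python
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)  # [[0., 1., 2.], [3., 4., 5.]]

# .at[...] returns an updated copy instead of modifying x in place.
y = x.at[0, 1].set(10.0)
z = x.at[:, 0].add(1.0)

col_sums = x.sum(axis=0)     # reduction method, as in NumPy
xt = x.T                     # all-axis transpose
xi = x.astype(jnp.int32)     # cast to a new dtype
```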

Callbacks

pure_callback(callback, result_shape_dtypes, ...): Calls a pure Python callback.
experimental.io_callback(callback, ...[, ...]): Calls an impure Python callback.
debug.callback(callback, *args[, ordered, ...]): Calls a stageable Python callback.
debug.print(fmt, *args[, ordered, partitioned]): Prints values and works in staged-out JAX functions.
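A sketch of pure_callback and debug.print inside a jitted function. The host-side function host_sin is illustrative; pure_callback needs the result shape and dtype declared up front (here via ShapeDtypeStruct) and assumes the callback has no side effects.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
    # Runs on the host as ordinary NumPy code; x arrives as a NumPy array.
    return np.sin(x).astype(x.dtype)

@jax.jit
def f(x):
    # Declare the callback's result shape/dtype so JAX can trace f abstractly.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    y = jax.pure_callback(host_sin, out_type, x)
    jax.debug.print("y = {}", y)  # prints runtime values, even under jit
    return y * 2.0

result = f(jnp.arange(3.0))
```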

Miscellaneous

Device: A descriptor of an available device.
print_environment_info([return_string]): Returns a string containing local environment & JAX installation information.
live_arrays([platform]): Return all live arrays in the backend for platform.
clear_caches(): Clear all compilation and staging caches.
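A short sketch exercising the miscellaneous helpers. Nothing here depends on specific hardware, though the environment report text differs between installs.

```python
import jax
import jax.numpy as jnp

# return_string=True returns the report instead of printing it.
info = jax.print_environment_info(return_string=True)

dev = jax.devices()[0]  # a jax.Device describing one accelerator or the CPU
x = jnp.ones(4)         # keeps at least one array alive
n_live = len(jax.live_arrays())

jax.clear_caches()      # drop compilation and staging caches
```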