Showing content from https://jax.readthedocs.io/en/latest/_autosummary/jax.smap.html below:

jax.smap — JAX documentation

jax.smap
jax.smap(f=None, /, *, in_axes=jax.sharding.Infer, out_axes, axis_name)

Single axis shard_map that maps a function f one axis at a time.

Parameters:
  • f – Callable to be mapped. Each application of f, or “instance” of f, takes as input a shard of the mapped-over arguments and produces a shard of the output.

  • in_axes – (optional) An integer, None, or sequence of values specifying which input array axes to map over. If not specified, smap tries to infer the axes from the arguments, which is supported only under Explicit mode. A single integer or None applies to all arguments and gives the array axis to map over (None means no axis is mapped); a tuple gives the axis to map for each corresponding positional argument. Axis integers must be in the range [-ndim, ndim) for each array, where ndim is the number of dimensions (axes) of the corresponding input array.

  • out_axes – An integer, None, or (nested) standard Python container (tuple/list/dict) thereof indicating where the mapped axis should appear in the output.

  • axis_name (AxisName) – Name of the mesh axis over which the function f is manual.

Returns:

A callable representing a mapped version of f, which accepts positional arguments corresponding to those of f and produces output corresponding to that of f.

