jax.sharding module

Classes
class jax.sharding.Sharding

Describes how a jax.Array is laid out across devices.
property addressable_devices

The set of devices in the Sharding that are addressable by the current process.
addressable_devices_indices_map(global_shape)

A mapping from addressable devices to the slice of array data each contains. addressable_devices_indices_map contains that part of devices_indices_map that applies to the addressable devices.

Parameters: global_shape (Shape)
Returns: Mapping[Device, Index | None]
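A minimal sketch of the mapping (assuming a single-process run, where every device is addressable; the device repr varies by backend):

>>> import jax
>>> s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
>>> # One addressable device holds the whole (4, 2) array:
>>> s.addressable_devices_indices_map((4, 2))  # doctest: +SKIP
{CpuDevice(id=0): (slice(None, None, None), slice(None, None, None))}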
property device_set

The set of devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., it includes non-addressable devices from other processes.
devices_indices_map(global_shape)

Returns a mapping from devices to the array slices each contains.

The mapping includes all global devices, including non-addressable devices from other processes.

Parameters: global_shape (Shape)
Returns: Mapping[Device, Index]
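A sketch for a one-dimensional NamedSharding (assumes at least two local devices; device names vary by backend):

>>> import jax, numpy as np
>>> from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
>>> s = NamedSharding(mesh, P('x'))
>>> for device, index in s.devices_indices_map((8,)).items():
...     print(device, index)  # doctest: +SKIP
TFRT_CPU_0 (slice(0, 4, None),)
TFRT_CPU_1 (slice(4, 8, None),)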
is_equivalent_to(other, ndim)

Returns True if two shardings are equivalent.

Two shardings are equivalent if they place the same logical array shards on the same devices.
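For example, for a rank-2 array a trailing None in a PartitionSpec is implicit, so the two shardings below place shards identically (a sketch assuming at least two local devices):

>>> import jax, numpy as np
>>> from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
>>> s1 = NamedSharding(mesh, P('x'))        # trailing dim implicitly replicated
>>> s2 = NamedSharding(mesh, P('x', None))  # same placement, spelled out
>>> s1.is_equivalent_to(s2, 2)
True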
property is_fully_addressable

Is this sharding fully addressable?

A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to "is_local" in multi-process JAX.
property is_fully_replicated

Is this sharding fully replicated?

A sharding is fully replicated if each device has a complete copy of the entire data.
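Both properties are simple queries; a sketch (assumes at least two local devices; in a single-process run every sharding is fully addressable):

>>> import jax, numpy as np
>>> from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
>>> s = NamedSharding(mesh, P())  # empty spec: every device gets a full copy
>>> s.is_fully_addressable  # doctest: +SKIP
True
>>> s.is_fully_replicated
True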
property memory_kind

Returns the memory kind of the sharding.
property num_devices

Number of devices that the sharding contains.
shard_shape(global_shape)

Returns the shape of the data on each device.

The shard shape returned by this function is calculated from global_shape and the properties of the sharding.

Parameters: global_shape (Shape)
Returns: Shape
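For example, partitioning an (8, 8) array over a 2x4 mesh with PartitionSpec('x', 'y') yields (4, 2) shards; a sketch assuming eight local devices:

>>> import jax, numpy as np
>>> from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> s = NamedSharding(mesh, P('x', 'y'))
>>> s.shard_shape((8, 8))
(4, 2)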
class jax.sharding.SingleDeviceSharding

Bases: Sharding

A Sharding that places its data on a single device.

Parameters: device – A single Device.
Examples
>>> single_device_sharding = jax.sharding.SingleDeviceSharding(
...     jax.devices()[0])
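The resulting sharding can be passed wherever a Sharding is expected, for example to jax.device_put() (a usage sketch):

>>> import jax.numpy as jnp
>>> arr = jax.device_put(jnp.arange(4), single_device_sharding)
>>> arr.sharding == single_device_sharding
True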
property device_set

The set of devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., it includes non-addressable devices from other processes.
devices_indices_map(global_shape)

Returns a mapping from devices to the array slices each contains.

The mapping includes all global devices, including non-addressable devices from other processes.

Parameters: global_shape (Shape)
Returns: Mapping[Device, Index]
property is_fully_addressable

Is this sharding fully addressable?

A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to "is_local" in multi-process JAX.
property is_fully_replicated

Is this sharding fully replicated?

A sharding is fully replicated if each device has a complete copy of the entire data.
property memory_kind

Returns the memory kind of the sharding.
property num_devices

Number of devices that the sharding contains.
with_memory_kind(kind)

Returns a new Sharding instance with the specified memory kind.

Parameters: kind (str)
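Memory kinds are backend-specific (for example, recent TPU/GPU runtimes expose 'device' and 'pinned_host'); a hedged sketch:

>>> s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
>>> host_sharding = s.with_memory_kind('pinned_host')  # doctest: +SKIP
>>> host_sharding.memory_kind  # doctest: +SKIP
'pinned_host'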
class jax.sharding.NamedSharding

Bases: Sharding

A NamedSharding expresses sharding using named axes.

A NamedSharding is a pair of a Mesh of devices and a PartitionSpec, which describes how to shard an array across that mesh.

A Mesh is a multidimensional NumPy array of JAX devices, where each axis of the mesh has a name, e.g. 'x' or 'y'.
A PartitionSpec is a tuple whose elements can be None, a mesh axis, or a tuple of mesh axes. Each element describes how an input dimension is partitioned across zero or more mesh dimensions. For example, PartitionSpec('x', 'y') says that the first dimension of the data is sharded across the x axis of the mesh and the second dimension is sharded across the y axis.

The Distributed arrays and automatic parallelization tutorial (https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names) has more details and diagrams that explain how Mesh and PartitionSpec are used.
Parameters:
mesh – A jax.sharding.Mesh object.
spec – A jax.sharding.PartitionSpec object.
Examples
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
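Continuing the example, the 2x4 mesh splits an (8, 8) array into (4, 2) per-device shards:

>>> named_sharding.shard_shape((8, 8))
(4, 2)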
property addressable_devices

The set of devices in the Sharding that are addressable by the current process.
property device_set

The set of devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., it includes non-addressable devices from other processes.
property is_fully_addressable

Is this sharding fully addressable?

A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to "is_local" in multi-process JAX.
property is_fully_replicated

Is this sharding fully replicated?

A sharding is fully replicated if each device has a complete copy of the entire data.
property memory_kind

Returns the memory kind of the sharding.

property mesh

(self) -> object
property num_devices

Number of devices that the sharding contains.

property spec

(self) -> object
with_memory_kind(kind)

Returns a new Sharding instance with the specified memory kind.

Parameters: kind (str)
class jax.sharding.PmapSharding

Bases: Sharding

Describes a sharding used by jax.pmap().

classmethod default(shape, sharded_dim=0, devices=None)

Creates a PmapSharding that matches the default placement used by jax.pmap().

Parameters:
shape (Shape) – The shape of the input array.
sharded_dim (int | None) – Dimension the input array is sharded on. Defaults to 0.
devices (Sequence[xc.Device] | None) – Optional sequence of devices to use. If omitted, the implicit device order used by pmap is used, which is the order of jax.local_devices().
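A minimal sketch: the leading axis must equal jax.local_device_count(), since that is the axis pmap maps over:

>>> import jax, jax.numpy as jnp
>>> n = jax.local_device_count()
>>> s = jax.sharding.PmapSharding.default((n, 4), sharded_dim=0)
>>> out = jax.pmap(lambda x: x * 2)(jnp.ones((n, 4)))
>>> # pmap's output placement matches the default sharding:
>>> out.sharding == s  # doctest: +SKIP
True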
property device_set

The set of devices that this Sharding spans.

In multi-controller JAX, the set of devices is global, i.e., it includes non-addressable devices from other processes.
property devices

(self) -> ndarray
devices_indices_map(global_shape)

Returns a mapping from devices to the array slices each contains.

The mapping includes all global devices, including non-addressable devices from other processes.

Parameters: global_shape (Shape)
Returns: Mapping[Device, Index]
is_equivalent_to(other, ndim)

Returns True if two shardings are equivalent.

Two shardings are equivalent if they place the same logical array shards on the same devices.

Parameters:
self (PmapSharding)
other (PmapSharding)
ndim (int)
property is_fully_addressable

Is this sharding fully addressable?

A sharding is fully addressable if the current process can address all of the devices named in the Sharding. is_fully_addressable is equivalent to "is_local" in multi-process JAX.
property is_fully_replicated

Is this sharding fully replicated?

A sharding is fully replicated if each device has a complete copy of the entire data.
property memory_kind

Returns the memory kind of the sharding.
property num_devices

Number of devices that the sharding contains.
shard_shape(global_shape)

Returns the shape of the data on each device.

The shard shape returned by this function is calculated from global_shape and the properties of the sharding.

Parameters: global_shape (Shape)
Returns: Shape
property sharding_spec

(self) -> jax::ShardingSpec
with_memory_kind(kind)

Returns a new Sharding instance with the specified memory kind.

Parameters: kind (str)
class jax.sharding.PartitionSpec

Tuple describing how to partition an array across a mesh of devices.

Each element is either None, a string, or a tuple of strings. See the documentation of jax.sharding.NamedSharding for more details.

This class exists so JAX's pytree utilities can distinguish a partition specification from tuples that should be treated as pytrees.
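A sketch of valid specs, and of the pytree-leaf behavior the last sentence describes:

>>> import jax
>>> from jax.sharding import PartitionSpec as P
>>> P('x', None)         # dim 0 split over mesh axis 'x'; dim 1 replicated
PartitionSpec('x', None)
>>> P(('x', 'y'), None)  # dim 0 split over both 'x' and 'y'
PartitionSpec(('x', 'y'), None)
>>> # Unlike a plain tuple, a PartitionSpec stays whole as a single leaf:
>>> jax.tree.leaves([P('x', None)])
[PartitionSpec('x', None)]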
property reduced

(self) -> frozenset

property unreduced

(self) -> frozenset
class jax.sharding.Mesh

Declare the hardware resources available in the scope of this manager.

See the Distributed arrays and automatic parallelization tutorial (https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) and the Explicit sharding tutorial (https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html) for more details.

Parameters:
devices (np.ndarray) – A NumPy ndarray object containing JAX device objects (as obtained, e.g., from jax.devices()).
axis_names (tuple[MeshAxisName, ...]) – A sequence of resource axis names to be assigned to the dimensions of the devices argument. Its length should match the rank of devices.
axis_types (tuple[AxisType, ...] | None) – An optional tuple of jax.sharding.AxisType entries corresponding to the axis_names. See Explicit Sharding for more information.
Examples
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P, NamedSharding
>>> import numpy as np
...
>>> # Declare a 2D mesh with axes `x` and `y`.
>>> devices = np.array(jax.devices()).reshape(4, 2)
>>> mesh = Mesh(devices, ('x', 'y'))
>>> inp = np.arange(16).reshape(8, 2)
>>> arr = jax.device_put(inp, NamedSharding(mesh, P('x', 'y')))
>>> out = jax.jit(lambda x: x * 2)(arr)
>>> assert out.sharding == NamedSharding(mesh, P('x', 'y'))
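Continuing the example, entering the mesh as a context manager declares its devices and named axes as the ambient resources for the enclosed code, which is the "scope of this manager" referred to above (a sketch):

>>> with mesh:
...     out = jax.jit(lambda x: x * 2)(arr)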