This tutorial serves as an introduction to device parallelism for Single-Program Multi-Data (SPMD) code in JAX. SPMD is a parallelism technique where the same computation, such as the forward pass of a neural network, can be run on different input data (for example, different inputs in a batch) in parallel on different devices, such as several GPUs or Google TPUs.
The tutorial covers three modes of parallel computation:
Automatic sharding via jax.jit(): the compiler chooses the optimal computation strategy (a.k.a. “the compiler takes the wheel”).
Explicit sharding (*new*): similar to automatic sharding in that you write a global-view program, but the sharding of each array is part of the array’s JAX-level type, making it an explicit part of the programming model. These shardings are propagated at the JAX level and are queryable at trace time. It is still the compiler’s responsibility to turn the whole-array program into per-device programs (turning jnp.sum into psum, for example), but the compiler is heavily constrained by the user-supplied shardings.
Fully manual sharding via jax.shard_map(): shard_map enables per-device code and explicit communication collectives.
A summary table:
| Mode     | View?      | Explicit sharding? | Explicit collectives? |
|----------|------------|--------------------|-----------------------|
| Auto     | Global     | ❌                 | ❌                    |
| Explicit | Global     | ✅                 | ❌                    |
| Manual   | Per-device | ✅                 | ✅                    |
Using these schools of thought for SPMD, you can transform a function written for one device into a function that can run in parallel on multiple devices.
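To make the three modes concrete before the detailed walkthrough, here is a minimal sketch (not part of the original tutorial) that applies each entry point to a trivial element-wise function. It assumes the same 8-CPU-device setup configured in the next cell; every call used here appears again, with a full explanation, in the sections below.

# Minimal sketch of the three parallelism modes; assumes 8 CPU devices.
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, AxisType

# Must run before JAX initializes its backends (same setting as the next cell).
jax.config.update('jax_num_cpu_devices', 8)

def double(x):
  return 2 * x

x = jnp.arange(8.0)

# 1. Automatic: shard the input and let jit/XLA choose the strategy.
auto_mesh = jax.make_mesh((8,), ('x',))
x_auto = jax.device_put(x, jax.NamedSharding(auto_mesh, P('x')))
print(jax.jit(double)(x_auto).sharding)

# 2. Explicit: the sharding is part of the JAX-level type and is queryable.
exp_mesh = jax.make_mesh((8,), ('X',), axis_types=(AxisType.Explicit,))
x_exp = jax.device_put(x, jax.NamedSharding(exp_mesh, P('X')))
with jax.sharding.use_mesh(exp_mesh):
  print(jax.typeof(jax.jit(double)(x_exp)))  # expected: float32[8@X]

# 3. Manual: shard_map runs per-device code over the mesh.
print(jax.shard_map(double, mesh=auto_mesh,
                    in_specs=P('x'), out_specs=P('x'))(x))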
import jax
jax.config.update('jax_num_cpu_devices', 8)
jax.devices()

[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]

Key concept: Data sharding
Key to all of the distributed computation approaches below is the concept of data sharding, which describes how data is laid out on the available devices.
How can JAX understand how the data is laid out across devices? JAX’s datatype, the jax.Array immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The jax.Array object is designed with distributed data and computation in mind. Every jax.Array has an associated jax.sharding.Sharding object, which describes which shard of the global data is required by each global device. When you create a jax.Array from scratch, you also need to create its Sharding.
In the simplest cases, arrays are sharded on a single device, as demonstrated below:
import numpy as np
import jax.numpy as jnp

arr = jnp.arange(32.0).reshape(4, 8)
arr.sharding

SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)
For a more visual representation of the storage layout, the jax.debug module provides some helpers to visualize the sharding of an array. For example, jax.debug.visualize_array_sharding() displays how the array is stored in the memory of a single device:
jax.debug.visualize_array_sharding(arr)
To create an array with a non-trivial sharding, you can define a jax.sharding specification for the array and pass this to jax.device_put().
Here, define a NamedSharding, which specifies an N-dimensional grid of devices with named axes, where jax.sharding.Mesh allows for precise device placement:
from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((2, 4), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)
NamedSharding(mesh=Mesh('x': 2, 'y': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('x', 'y'), memory_kind=unpinned_host)
Passing this Sharding object to jax.device_put(), you can obtain a sharded array:
arr_sharded = jax.device_put(arr, sharding)

print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)

[[ 0.  1.  2.  3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20. 21. 22. 23.]
 [24. 25. 26. 27. 28. 29. 30. 31.]]
  CPU 0    CPU 1    CPU 2    CPU 3
  CPU 4    CPU 5    CPU 6    CPU 7

1. Automatic parallelism via jit
Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a jax.jit()-compiled function! In JAX, you only need to specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.
The XLA compiler behind jit includes heuristics for optimizing computations across multiple devices. In the simplest of cases, those heuristics boil down to “computation follows data”.
To demonstrate how auto-parallelization works in JAX, below is an example that uses a jax.jit()-decorated staged-out function: it’s a simple element-wise function, where the computation for each shard will be performed on the device associated with that shard, and the output is sharded in the same way:
@jax.jit
def f_elementwise(x):
  return 2 * jnp.sin(x) + 1

result = f_elementwise(arr_sharded)
print("shardings match:", result.sharding == arr_sharded.sharding)
As computations get more complex, the compiler makes decisions about how to best propagate the sharding of the data.
Here, you sum along the leading axis of x, and visualize how the result values are stored across multiple devices (with jax.debug.visualize_array_sharding()):
@jax.jit
def f_contract(x):
  return x.sum(axis=0)

result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)
CPU 0,4 CPU 1,5 CPU 2,6 CPU 3,7
[48. 52. 56. 60. 64. 68. 72. 76.]
The result is partially replicated: that is, the first two elements of the array are replicated on devices 0 and 4, the second two on devices 1 and 5, and so on.
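As a small addition (not in the original tutorial), you can confirm this programmatically by inspecting the sharding object of result from the cell above; the exact spec string may vary across JAX versions:

# Inspect the compiler-chosen output sharding: 'y'-sharded, replicated over 'x'.
print(result.sharding)                      # expected: a NamedSharding with spec P('y')
print(result.sharding.is_fully_replicated)  # False: replicated only along the 'x' axis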
2. Explicit sharding

The main idea behind explicit sharding (a.k.a. sharding-in-types) is that the JAX-level type of a value includes a description of how the value is sharded. We can query the JAX-level type of any JAX value (or NumPy array, or Python scalar) using jax.typeof:
some_array = np.arange(8)
print(f"JAX-level type of some_array: {jax.typeof(some_array)}")
JAX-level type of some_array: int32[8]
Importantly, we can query the type even while tracing under a jit (the JAX-level type is almost defined as “the information about a value we have access to while under a jit”).
@jax.jit
def foo(x):
  print(f"JAX-level type of x during tracing: {jax.typeof(x)}")
  return x + x

foo(some_array)
JAX-level type of x during tracing: int32[8]
Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)
To start seeing shardings in the type, we need to set up an explicit-sharding mesh.
from jax.sharding import AxisType

mesh = jax.make_mesh((2, 4), ("X", "Y"),
                     axis_types=(AxisType.Explicit, AxisType.Explicit))
Now we can create some sharded arrays:
replicated_array = np.arange(8).reshape(4, 2)
sharded_array = jax.device_put(replicated_array, jax.NamedSharding(mesh, P("X", None)))

print(f"replicated_array type: {jax.typeof(replicated_array)}")
print(f"sharded_array type: {jax.typeof(sharded_array)}")

replicated_array type: int32[4,2]
sharded_array type: int32[4@X,2]
We should read the type int32[4@X,2] as “a 4-by-2 array of 32-bit integers whose first dimension is sharded along mesh axis ‘X’. The array is replicated along all other mesh axes.”
These shardings associated with JAX-level types propagate through operations. For example:
arg0 = jax.device_put(np.arange(4).reshape(4, 1),
                      jax.NamedSharding(mesh, P("X", None)))
arg1 = jax.device_put(np.arange(8).reshape(1, 8),
                      jax.NamedSharding(mesh, P(None, "Y")))

@jax.jit
def add_arrays(x, y):
  ans = x + y
  print(f"x sharding: {jax.typeof(x)}")
  print(f"y sharding: {jax.typeof(y)}")
  print(f"ans sharding: {jax.typeof(ans)}")
  return ans

with jax.sharding.use_mesh(mesh):
  add_arrays(arg0, arg1)

x sharding: int32[4@X,1]
y sharding: int32[1,8@Y]
ans sharding: int32[4@X,8@Y]
That’s the gist of it: shardings propagate deterministically, and we can query them at trace time.
3. Manual parallelism with shard_map
In the automatic parallelism methods explored above, you can write a function as if you’re operating on the full dataset, and jit will split that computation across multiple devices. By contrast, with jax.shard_map() you write the function that will handle a single shard of data, and shard_map will construct the full function.
shard_map works by mapping a function across a particular mesh of devices (shard_map maps over shards). In the example below:
As before, jax.sharding.Mesh allows for precise device placement, with the axis names parameter assigning logical names to the physical device axes.
The in_specs argument determines the shard sizes, and the out_specs argument identifies how the blocks are assembled back together.
Note: jax.shard_map() code can work inside jax.jit() if you need it.
mesh = jax.make_mesh((8,), ('x',))

f_elementwise_sharded = jax.shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x'))

arr = jnp.arange(32)
f_elementwise_sharded(arr)
Array([ 1. , 2.682942 , 2.818595 , 1.28224 , -0.513605 , -0.9178486 , 0.44116902, 2.3139732 , 2.9787164 , 1.824237 , -0.08804226, -0.99998045, -0.07314587, 1.840334 , 2.9812148 , 2.3005757 , 0.42419338, -0.92279494, -0.50197446, 1.2997544 , 2.8258905 , 2.6733112 , 0.98229736, -0.69244087, -0.81115675, 0.7352965 , 2.525117 , 2.912752 , 1.5418116 , -0.32726777, -0.97606325, 0.19192469], dtype=float32)
The function you write only “sees” a single batch of the data, which you can check by printing the device local shape:
x = jnp.arange(32)
print(f"global shape: {x.shape=}")

def f(x):
  print(f"device local shape: {x.shape=}")
  return x * 2

y = jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)

global shape: x.shape=(32,)
device local shape: x.shape=(4,)
Because each of your functions only “sees” the device-local part of the data, it means that aggregation-like functions require some extra thought.
For example, here’s what a shard_map of a jax.numpy.sum() looks like:
def f(x):
  return jnp.sum(x, keepdims=True)

jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
Array([ 6, 22, 38, 54, 70, 86, 102, 118], dtype=int32)
Your function f operates separately on each shard, and the resulting summation reflects this.
If you want to sum across shards, you need to explicitly request it using collective operations like jax.lax.psum():
def f(x):
  sum_in_shard = x.sum()
  return jax.lax.psum(sum_in_shard, 'x')

jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
Because the output no longer has a sharded dimension, set out_specs=P() (recall that the out_specs argument identifies how the blocks are assembled back together in shard_map).
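As a compact recap (an added sketch reusing the mesh and x defined above, not part of the original tutorial), compare how the two out_specs choices reassemble the outputs:

# Per-shard sums stitched back along 'x' (shape (8,)) vs. one replicated total (shape ()).
per_shard_sums = jax.shard_map(lambda s: s.sum(keepdims=True),
                               mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
global_sum = jax.shard_map(lambda s: jax.lax.psum(s.sum(), 'x'),
                           mesh=mesh, in_specs=P('x'), out_specs=P())(x)
print(per_shard_sums.shape, global_sum.shape)  # (8,) and ()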
With these concepts fresh in mind, let’s compare the three approaches for a simple neural network layer.
Start by defining your canonical function like this:
@jax.jit
def layer(x, weights, bias):
  return jax.nn.sigmoid(x @ weights + bias)
import numpy as np

rng = np.random.default_rng(0)

x = rng.normal(size=(32,))
weights = rng.normal(size=(32, 4))
bias = rng.normal(size=(4,))

layer(x, weights, bias)
Array([0.02138916, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)
You can automatically run this in a distributed manner using jax.jit() and passing appropriately sharded data.
If you shard the leading axis of x and make weights fully replicated, then the matrix multiplication will automatically happen in parallel:
mesh = jax.make_mesh((8,), ('x',))

x_sharded = jax.device_put(x, jax.NamedSharding(mesh, P('x')))
weights_sharded = jax.device_put(weights, jax.NamedSharding(mesh, P()))

layer(x_sharded, weights_sharded, bias)
Array([0.02138916, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)
Alternatively, you can use explicit sharding mode too:
explicit_mesh = jax.make_mesh((8,), ('X',), axis_types=(AxisType.Explicit,))

x_sharded = jax.device_put(x, jax.NamedSharding(explicit_mesh, P('X')))
weights_sharded = jax.device_put(weights, jax.NamedSharding(explicit_mesh, P()))

@jax.jit
def layer_auto(x, weights, bias):
  print(f"x sharding: {jax.typeof(x)}")
  print(f"weights sharding: {jax.typeof(weights)}")
  print(f"bias sharding: {jax.typeof(bias)}")
  out = layer(x, weights, bias)
  print(f"out sharding: {jax.typeof(out)}")
  return out

with jax.sharding.use_mesh(explicit_mesh):
  layer_auto(x_sharded, weights_sharded, bias)

x sharding: float32[32@X]
weights sharding: float32[32,4]
bias sharding: float32[4]
out sharding: float32[4]
Finally, you can do the same thing with shard_map, using jax.lax.psum() to indicate the cross-shard collective required for the matrix product:
from functools import partial

@jax.jit
@partial(jax.shard_map, mesh=mesh,
         in_specs=(P('x'), P('x', None), P(None)),
         out_specs=P(None))
def layer_sharded(x, weights, bias):
  return jax.nn.sigmoid(jax.lax.psum(x @ weights, 'x') + bias)

layer_sharded(x, weights, bias)
Array([0.02138916, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)

Next steps
This tutorial serves as a brief introduction to sharded and parallel computation in JAX.
To learn about each SPMD method in depth, check out the JAX documentation pages dedicated to automatic parallelism, explicit sharding, and shard_map.