MPMD support
We support two modes of distributed computing. In Single Process Multiple Devices (SPMD) mode, a single process per node manages one or more devices. We also support MPMD mode, in which the user launches one process for each GPU. When jaxmg is imported, it attempts to verify that the user's distributed setup does not go beyond these two modes of computation.
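A minimal sketch of the kind of check described above is shown below. It assumes that each process reports its hostname and its local device count (in JAX, these could come from `socket.gethostname()` and `jax.local_device_count()`). The `classify_setup` helper is hypothetical and is not jaxmg's actual verification code.

```python
from collections import Counter


def classify_setup(processes):
    """Classify a distributed setup as "SPMD" or "MPMD".

    `processes` holds one (hostname, local_device_count) pair per process.
    Hypothetical helper for illustration only; not jaxmg's actual check.
    """
    if not processes:
        raise ValueError("no processes given")
    procs_per_host = Counter(host for host, _ in processes)
    # SPMD: exactly one process per node, which may manage several devices.
    if all(count == 1 for count in procs_per_host.values()):
        return "SPMD"
    # MPMD: several processes per node are allowed, but each owns one device.
    if all(n_devices == 1 for _, n_devices in processes):
        return "MPMD"
    raise ValueError(
        "unsupported setup: use one process per node or one process per GPU"
    )


print(classify_setup([("node0", 4), ("node1", 4)]))  # SPMD
print(classify_setup([("node0", 1)] * 4))  # MPMD
```

Any setup that fits neither pattern, such as two processes on one node that each own two devices, raises an error.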
jaxmg supports multi-process jax.distributed environments, but cuSolverMg can only run on a single node. This limitation has technical reasons that will hopefully be resolved in a future release.
To work around it, one can perform the computation over all global devices, replicate the result on every host by gathering the data, and then call the solver separately on each machine.
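Here we provide an example of using jaxmg with 2 nodes, each with 4 GPUs. So that the example can be run on a single machine, the script below emulates this setup with two processes on localhost, each with 4 CPU devices.
To use the solver, we have to gather the data onto each node, which we do with a 2D mesh whose first axis ("x") indexes hosts and whose second axis ("y") indexes the devices local to each host.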
```python
# Call ./examples/multi_process.sh to launch this code!
import sys

# Process rank and total number of processes, passed on the command line.
proc_id = int(sys.argv[1]) if len(sys.argv) > 1 else 0
num_procs = int(sys.argv[2]) if len(sys.argv) > 2 else 1

import jax

# Emulate GPUs with CPU devices for this demo; drop this line on a GPU machine.
jax.config.update("jax_platform_name", "cpu")

# Initialize the distributed system (coordinator on localhost:6000).
jax.distributed.initialize(
    coordinator_address="localhost:6000",
    num_processes=num_procs,
    process_id=proc_id,
)

import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def get_device_grid():
    """Arrange all global devices in a (num_hosts, local_device_count) grid."""
    by_proc = {}
    for d in jax.devices():
        by_proc.setdefault(d.process_index, []).append(d)
    hosts = sorted(by_proc)
    return np.array(
        [[by_proc[h][x] for x in range(jax.local_device_count())] for h in hosts]
    )


def create_2d_mesh():
    # "x" indexes hosts, "y" indexes the devices local to each host.
    dev_grid = get_device_grid()
    return Mesh(dev_grid, ("x", "y"))


def create_1d_mesh():
    # A single "y" axis spanning all global devices.
    dev_grid = get_device_grid()
    return Mesh(dev_grid.flatten(), ("y",))


print(f"Rank {proc_id}")
print(f"Local devices {jax.local_device_count()}")
print(f"Global devices {jax.device_count()}")
print(f"World size {num_procs}")
print(f"Device grid\n{get_device_grid()}")
```
We launch the code with the following script, which starts two processes with 4 CPU devices each:
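```bash
#!/bin/bash
# Give each process 4 CPU devices to stand in for 4 GPUs.
export JAX_NUM_CPU_DEVICES=4
num_processes=2

range=$(seq 0 $(($num_processes - 1)))
# Start one background process per rank, writing each one's output to a file.
for i in $range; do
    python multi_process.py $i $num_processes > /tmp/multi_process_$i.out &
done
wait

# Print the output of every process.
for i in $range; do
    echo "=================== process $i output ==================="
    cat /tmp/multi_process_$i.out
    echo
done
```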
The combined output is:

```text
=================== process 0 output ===================
Rank 0
Local devices 4
Global devices 8
World size 2
Device grid
[[CpuDevice(id=0) CpuDevice(id=1) CpuDevice(id=2) CpuDevice(id=3)]
[CpuDevice(id=131072) CpuDevice(id=131073) CpuDevice(id=131074) CpuDevice(id=131075)]]
=================== process 1 output ===================
Rank 1
Local devices 4
Global devices 8
World size 2
Device grid
[[CpuDevice(id=0) CpuDevice(id=1) CpuDevice(id=2) CpuDevice(id=3)]
[CpuDevice(id=131072) CpuDevice(id=131073) CpuDevice(id=131074) CpuDevice(id=131075)]]
```
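The grouping performed by `get_device_grid` can be checked without launching any processes. The sketch below mirrors its logic using a hypothetical `FakeDevice` stand-in that carries only `id` and `process_index`. It stores device ids rather than device objects, because a namedtuple is itself a sequence and NumPy would otherwise unpack it into an extra array dimension.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for a JAX device, carrying only the two attributes
# that get_device_grid reads.
FakeDevice = namedtuple("FakeDevice", ["id", "process_index"])


def device_id_grid(devices, local_device_count):
    """Group device ids into a (num_hosts, local_device_count) grid, one row per process."""
    by_proc = {}
    for d in devices:
        by_proc.setdefault(d.process_index, []).append(d)
    hosts = sorted(by_proc)
    # Store ids: a namedtuple is a sequence, so NumPy would otherwise unpack it.
    return np.array(
        [[by_proc[h][x].id for x in range(local_device_count)] for h in hosts]
    )


# Device ids as they appear in the output above: 2 processes, 4 devices each.
devices = [FakeDevice(i, 0) for i in range(4)] + [
    FakeDevice(131072 + i, 1) for i in range(4)
]
grid = device_id_grid(devices, local_device_count=4)
print(grid.shape)  # (2, 4)
```

Row h of the grid lists the devices owned by process h, so when the grid is passed to `Mesh(dev_grid, ("x", "y"))`, the "x" axis indexes hosts and the "y" axis indexes the devices local to each host.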
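Next, we place a diagonal test matrix on the 2D mesh with its columns sharded over all global devices, and then reshard it so that its columns are sharded only over the devices within each host:

```python
mesh2d = create_2d_mesh()

# Diagonal (n_devices x n_devices) matrix, columns split over both "x" and "y".
A = jax.device_put(
    jnp.diag(jnp.arange(1, jax.device_count() + 1, dtype=jnp.float32)),
    NamedSharding(mesh2d, P(None, ("x", "y"))),
)
for shard in A.addressable_shards:
    print(f"shard\n {shard.data}")

# Gather over the hosts: drop the "x" axis so that every host holds a full
# copy of A, with its columns sharded over the local devices ("y").
A = jax.lax.with_sharding_constraint(A, NamedSharding(mesh2d, P(None, "y")))
for shard in A.addressable_shards:
    print(f"shard\n {shard.data}")
```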
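Appending these lines to the script and launching it again prints the following (some shards are elided with `...`):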
```text
=================== process 0 output ===================
Rank 0
Local devices 4
Global devices 8
World size 2
Device grid
[[CpuDevice(id=0) CpuDevice(id=1) CpuDevice(id=2) CpuDevice(id=3)]
[CpuDevice(id=131072) CpuDevice(id=131073) CpuDevice(id=131074)
CpuDevice(id=131075)]]
shard
[[1.]
[0.]
[0.]
[0.]
[0.]
[0.]
[0.]
[0.]]
...
shard
[[0.]
[0.]
[0.]
[4.]
[0.]
[0.]
[0.]
[0.]]
shard
[[1. 0.]
[0. 2.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]]
...
shard
[[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[7. 0.]
[0. 8.]]
=================== process 1 output ===================
...
shard
[[0.]
[0.]
[0.]
[0.]
[5.]
[0.]
[0.]
[0.]]
shard
...
shard
[[0.]
[0.]
[0.]
[0.]
[0.]
[0.]
[0.]
[8.]]
shard
[[1. 0.]
[0. 2.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]]
...
shard
[[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[7. 0.]
[0. 8.]]
```
We went from a matrix whose columns were sharded over all 8 global devices to one whose columns are sharded over the 4 local devices of each process, with a full copy of the matrix on every host.
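To see where each column ends up, the sketch below computes which column indices each device `(x, y)` of the 2 × 4 mesh holds under `P(None, ("x", "y"))` and under `P(None, "y")`. It is a plain NumPy model of the layout, not a JAX API, and it assumes that the column count divides evenly and that, for a tuple of mesh axes, the first axis is the major one. This matches the printed shards: before resharding, the first device of process 1 holds column 4 (the entry 5.), and afterwards the last device of each process holds columns 6 and 7.

```python
import numpy as np


def columns_per_device(n_cols, nx, ny, spec):
    """Column indices held by device (x, y) of an (nx, ny) mesh with axes ("x", "y").

    spec is ("x", "y") for P(None, ("x", "y")) or ("y",) for P(None, "y").
    Plain NumPy model of the layout, not a JAX API.
    """
    if spec == ("x", "y"):
        n_shards = nx * ny
    elif spec == ("y",):
        n_shards = ny
    else:
        raise ValueError(f"unsupported spec {spec!r}")
    if n_cols % n_shards:
        raise ValueError("number of columns must be divisible by the number of shards")
    # Contiguous, equally sized blocks of columns, one per shard.
    blocks = np.arange(n_cols).reshape(n_shards, n_cols // n_shards)
    return {
        # With ("x", "y"), "x" is the major axis; with ("y",), "x" is replicated.
        (x, y): blocks[x * ny + y if spec == ("x", "y") else y].tolist()
        for x in range(nx)
        for y in range(ny)
    }


before = columns_per_device(8, 2, 4, ("x", "y"))
after = columns_per_device(8, 2, 4, ("y",))
print(before[(1, 0)])  # [4]
print(after[(1, 0)])  # [0, 1]
```

Because devices with the same y index on the two hosts hold identical columns, each host owns a full copy of A after the reshard, which is what lets the single-node solver run on every host independently.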
In this host-replicated layout, we can safely call jaxmg.potrs on the array using the 2D mesh (the code below only runs when the devices are actual GPUs):
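```python
from jaxmg import potrs

# Solve A x = b with b a vector of ones. A is sharded by columns over the
# local devices of each host ("y"), while b is replicated.
out = potrs(
    A,
    jnp.ones((jax.device_count(), 1), dtype=jnp.float32),
    T_A=256,
    mesh=mesh2d,
    in_specs=(P(None, "y"), P(None, None)),
)
```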
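To sanity-check `out` on a GPU machine, it helps to know the expected answer. A = diag(1, ..., 8) is symmetric positive definite and b is a vector of ones, so the exact solution is x_i = 1/i. The NumPy sketch below computes a single-device reference with a Cholesky factorization, the approach the LAPACK/cuSOLVER potrs family of routines is built on. It is a cross-check under that assumption, not jaxmg's implementation.

```python
import numpy as np

n = 8  # jax.device_count() in the two-process example
A = np.diag(np.arange(1, n + 1, dtype=np.float64))
b = np.ones((n, 1))

# Cholesky-based solve: factor A = L L^T, then solve L y = b and L^T x = y.
L = np.linalg.cholesky(A)
y = np.linalg.solve(L, b)
x = np.linalg.solve(L.T, y)

# For A = diag(1, ..., n) and b = 1, the exact solution is x_i = 1 / i.
print(np.allclose(x[:, 0], 1.0 / np.arange(1, n + 1)))  # True
```

Comparing this reference with the distributed result, for example with `np.allclose(np.asarray(out), x, rtol=1e-5)`, assumes that `potrs` returns the solution as a single array.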