Skip to content

Cholesky decomposition with jaxmg.potrs

Here, we give an example of calling jaxmg.potrs, which solves the linear system of equations \(Ax=b\) for a symmetric, positive-definite matrix \(A\) via a Cholesky decomposition.

The interface of jaxmg.potrs is simple to use: one only needs to supply the underlying mesh of the sharded data and specify the input shardings:

import jax
# Enable 64-bit types before any arrays are created; otherwise JAX would
# silently downcast the float64 data below to float32.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, NamedSharding
from jaxmg import potrs
print(f"Devices: {jax.devices()}")
# Assumes we have at least one GPU available
devices = jax.devices("gpu")
# Matrix size; should be divisible by the number of devices so that the
# rows of A split evenly across the mesh.
N = 12
# Forwarded to potrs as T_A (presumably the tile size; see the jaxmg docs).
T_A = 3
dtype = jnp.float64
# Create the diagonal matrix A = diag(1, 2, ..., N) and a right-hand side
# `b` whose entries are all equal to one.
A = jnp.diag(jnp.arange(N, dtype=dtype) + 1)
b = jnp.ones((N, 1), dtype=dtype)
ndev = len(devices)
# Make a 1D mesh over the GPU devices and place the data (rows of A sharded,
# b replicated). Passing `devices` explicitly keeps the mesh consistent with
# `ndev` even when the default backend is not the GPU.
mesh = jax.make_mesh((ndev,), ("x",), devices=devices)
A = jax.device_put(A, NamedSharding(mesh, P("x", None)))
b = jax.device_put(b, NamedSharding(mesh, P(None, None)))
# Call potrs; in_specs describe the same shardings used for A and b above.
out = potrs(A, b, T_A=T_A, mesh=mesh, in_specs=(P("x", None), P(None, None)))
print(out)
# A is diagonal and b is all ones, so the solution is x_i = 1 / A_ii.
expected_out = 1.0 / (jnp.arange(N, dtype=dtype) + 1)
print(jnp.allclose(out.flatten(), expected_out))
mkdir -p failed for path /home/rwiersema/.cache/matplotlib: [Errno 13] Permission denied: '/home/rwiersema'
Matplotlib created a temporary cache directory at /tmp/matplotlib-kkpjcj85 because there was an issue with the default path (/home/rwiersema/.cache/matplotlib); it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.
Devices: [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2)]
[[1.        ]
 [0.5       ]
 [0.33333333]
 [0.25      ]
 [0.2       ]
 [0.16666667]
 [0.14285714]
 [0.125     ]
 [0.11111111]
 [0.1       ]
 [0.09090909]
 [0.08333333]]
True

Since \(A\) is diagonal and \(b\) is a vector of ones, the solution is simply the vector of inverse diagonal entries, \(x_i = 1/A_{ii}\).

The jaxmg.potrs API is convenient when you are not already working inside a jax.shard_map context, since it wraps the call in a shard map for you. In a typical application, however, we may already be working inside a sharded context. For that case we provide an additional API, jaxmg.potrs_shardmap_ctx, which can be called directly from within such a context.

from jaxmg import potrs_shardmap_ctx
from functools import partial

# Show which devices JAX can see before running the sharded example.
print("Devices:", jax.devices())

def shard_mapped_fn(_a, _b, _T_A):
    """Rescale this device's shard of A, then solve with ``potrs_shardmap_ctx``.

    Intended to run inside ``jax.shard_map`` over mesh axis ``"x"``.

    Args:
        _a: The local row shard of A on this device.
        _b: The right-hand side, replicated on every device.
        _T_A: Forwarded unchanged to ``potrs_shardmap_ctx`` (presumably the
            tile size; see the jaxmg docs).

    Returns:
        Whatever ``potrs_shardmap_ctx`` returns. The caller unpacks it as
        ``(solution, status)``.
    """
    # Factor is (index of this device along "x") + 1, so shard k is scaled by k + 1.
    scale = jax.lax.axis_index("x") + 1
    scaled_shard = _a * scale
    return potrs_shardmap_ctx(scaled_shard, _b, _T_A)

def my_fn(_a, _b, _T_A):
    """Run ``shard_mapped_fn`` under ``jax.shard_map`` on the module-level ``mesh``.

    Args:
        _a: Matrix A, sharded by rows along mesh axis ``"x"``.
        _b: Right-hand side, replicated across the mesh.
        _T_A: Forwarded to ``shard_mapped_fn`` (and on to ``potrs_shardmap_ctx``).

    Returns:
        A ``(solution, status)`` pair. The solution uses spec ``P(None, None)``
        and the status uses ``P(None)``.

    Note:
        This reads the global ``mesh`` at call time, so ``mesh`` must be
        defined before the function is called.
    """
    sharded_solve = jax.shard_map(
        partial(shard_mapped_fn, _T_A=_T_A),
        mesh=mesh,
        in_specs=(P("x", None), P(None, None)),
        # The solver always returns a status alongside the solution.
        out_specs=(P(None, None), P(None)),
        # NOTE(review): varying-manual-axes checking is disabled, presumably
        # because the external solver call cannot be checked; confirm.
        check_vma=False,
    )
    return sharded_solve(_a, _b)


# Assumes we have at least one GPU available
devices = jax.devices("gpu")
# Matrix size; should be divisible by the number of devices so that the
# rows of A split evenly across the mesh.
N = 6
T_A = 1
dtype = jnp.float32
# Create the identity matrix; each row shard is rescaled inside
# shard_mapped_fn by (device index along "x") + 1.
A = jnp.eye(N, dtype=dtype)
b = jnp.ones((N, 1), dtype=dtype)
ndev = len(devices)
# Make a 1D mesh over the GPU devices and place the data (rows of A sharded,
# b replicated). Passing `devices` explicitly keeps the mesh consistent with
# `ndev` even when the default backend is not the GPU. Note that `my_fn`
# reads this module-level `mesh`, so it must be defined before the call below.
mesh = jax.make_mesh((ndev,), ("x",), devices=devices)
A = jax.device_put(A, NamedSharding(mesh, P("x", None)))
b = jax.device_put(b, NamedSharding(mesh, P(None, None)))
# Call potrs from inside the shard_map context; a solver status is returned
# alongside the solution.
out, status = my_fn(A, b, T_A)
print(out.block_until_ready())
print(f"Solver status: {status}")
Devices: [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2)]
[[1.        ]
 [1.        ]
 [0.49999997]
 [0.49999997]
 [0.3333333 ]
 [0.3333333 ]]
Solver status: [0]

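After scaling each device's row block of \(A\) by its index along mesh axis "x" plus one, the three devices shown above turn the identity into \(A = \mathrm{diag}(1,1,2,2,3,3)\). With \(b\) all ones, we get the expected solution \((1,1,\frac{1}{2},\frac{1}{2},\frac{1}{3},\frac{1}{3})\).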