jaxmg.potrs

jaxmg.potrs(a, b, T_A, mesh, in_specs, return_status=False, pad=True)

Solve the linear system A x = B using the multi-GPU potrs native kernel.

Prepares inputs for the native potrs_mg kernel and executes it via jax.ffi.ffi_call under jax.jit and jax.shard_map. Handles per-device padding driven by T_A and returns the solution (and optionally a host-side solver status).

Tip

If the matrix shards cannot be covered evenly by tiles of size T_A (i.e. (N / num_gpus) % T_A != 0), padding has to be added to fill the last tile. This requires copying the matrix, which should be avoided at all costs for large N. Pick T_A large enough (>= 128) and such that it evenly divides the shard length. In principle, increasing T_A improves performance at the cost of memory, although, depending on N, performance eventually saturates.
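
As a quick illustration of this divisibility check (a minimal sketch; N, num_gpus, and T_A below are placeholder values, not requirements of the solver):

```python
# Hypothetical sizes used only to illustrate the condition from the tip above.
N, num_gpus, T_A = 16384, 4, 256

shard_len = N // num_gpus             # rows owned by each device: 4096
needs_padding = shard_len % T_A != 0  # False: 4096 is a multiple of 256, so no copy is needed
```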

Parameters:

  • a (Array, required): 2D symmetric matrix representing the coefficient matrix. Expected to be sharded across the mesh along the first (row) axis with a single PartitionSpec: P(<axis_name>, None).
  • b (Array, required): 2D right-hand side. Expected to be replicated across devices with PartitionSpec P(None, None).
  • T_A (int, required): Tile width used by the native solver. Each local shard length must be a multiple of T_A; if the provided T_A is incompatible with the shard size, the matrix is padded accordingly. For small tile sizes (T_A < 128) the solver can be extremely slow, so make sure T_A is large enough. In principle, the larger T_A, the faster the solver runs.
  • mesh (Mesh, required): JAX Mesh object used for jax.shard_map.
  • in_specs (tuple | list of PartitionSpec, required): The sharding specifications for (a, b). Expected to be (P(<axis_name>, None), P(None, None)).
  • return_status (bool, default False): If True, return (x, status), where status is a host-replicated int32 from the native solver. If False, return x only.
  • pad (bool, default True): If True, apply per-device padding to a so that each local shard length is compatible with T_A; if False, the caller must ensure shapes already match the kernel's requirements.

Returns:

  • jax.Array | tuple[jax.Array, int]: The solution x (replicated across devices). If return_status=True, the native solver status is returned as well.

Raises:

  • AssertionError: If a or b is not 2D, or their shapes are incompatible.
  • ValueError: If in_specs is not a 2-element sequence, or if the provided PartitionSpec objects do not match the required patterns (P(<axis_name>, None) for a and P(None, None) for b).

Notes
  • The FFI call may donate the a buffer (donate_argnums=0) for zero-copy interaction with the native library.
  • If the native solver fails, the returned solution may contain NaNs and the status will be non-zero.
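
For illustration, a minimal usage sketch (not part of the API reference), assuming a single-axis mesh named "x"; the matrix construction and sizes are placeholders chosen so the per-device shard length is a multiple of T_A:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import jaxmg

mesh = Mesh(jax.devices(), axis_names=("x",))

N, T_A = 4096, 256                     # N // num_devices should be a multiple of T_A
key = jax.random.PRNGKey(0)
m = jax.random.normal(key, (N, N))
a = m @ m.T + N * jnp.eye(N)           # symmetric positive definite coefficient matrix
b = jnp.ones((N, 1))                   # right-hand side

# Shard a along rows and replicate b, matching the required in_specs.
a = jax.device_put(a, NamedSharding(mesh, P("x", None)))
b = jax.device_put(b, NamedSharding(mesh, P(None, None)))

x, status = jaxmg.potrs(
    a, b, T_A, mesh,
    in_specs=(P("x", None), P(None, None)),
    return_status=True,
)
```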

jaxmg.potrs_shardmap_ctx(a, b, T_A, pad=True)

Solve the linear system A x = B by invoking the native multi-GPU potrs kernel directly, without wrapping it in shard_map.

This helper is a lightweight, lower-level variant of jaxmg.potrs intended for contexts where the input a is already laid out and sharded at the application level (for example, when running inside a custom shard_map/pjit-managed context). It performs the same T_A-driven padding logic and calls the native potrs_mg FFI targets directly via jax.ffi.ffi_call instead of constructing an additional shard_map wrapper.

Tip

If the matrix shards cannot be covered evenly by tiles of size T_A (i.e. (N / num_gpus) % T_A != 0), padding has to be added to fill the last tile. This requires copying the matrix, which should be avoided at all costs for large N. Pick T_A large enough (>= 128) and such that it evenly divides the shard length. In principle, increasing T_A improves performance at the cost of memory, although, depending on N, performance eventually saturates.

Parameters:

  • a (Array, required): 2D coefficient matrix of shape (N_rows // ndev, N), i.e. this device's row block. The matrix must be symmetric for correct solver behavior.
  • b (Array, required): 2D right-hand side. Its first dimension must equal the number of columns of a (i.e. a.shape[1] == b.shape[0]).
  • T_A (int, required): Tile width used by the native solver. Each local shard length must be a multiple of T_A; if the provided T_A is incompatible with the shard size, the matrix is padded accordingly. For small tile sizes (T_A < 128) the solver can be extremely slow, so make sure T_A is large enough. In principle, the larger T_A, the faster the solver runs.
  • pad (bool, default True): If True, apply per-device padding to a so that each local shard length is compatible with T_A. If False, the caller must ensure shapes already meet the kernel's requirements.

Returns:

  • tuple[jax.Array, jax.Array]: (x, status), where x is the solver result (same shape as b) and status is the int32 status value returned by the native kernel (a device array of shape (1,)).

Raises:

  • AssertionError: If the input arrays are not 2D or their shapes are incompatible.

Notes
  • This function does not apply sharding via jax.shard_map itself and must therefore be called from within an existing shard_map context.
  • Because it does not use donate_argnums, the input buffers are not donated to the FFI call (no zero-copy donation semantics).
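
For illustration, a sketch of calling this helper from inside a caller-managed shard_map. The single-axis mesh named "x" is an assumption, and the exact shard_map entry point and replication-check flag depend on your JAX version (older versions use jax.experimental.shard_map.shard_map with check_rep=False):

```python
import jax
from functools import partial
from jax.sharding import Mesh, PartitionSpec as P
import jaxmg

mesh = Mesh(jax.devices(), axis_names=("x",))
T_A = 256

@partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(P("x", None), P(None, None)),   # a sharded along rows, b replicated
    out_specs=(P(None, None), P(None)),       # x and status are replicated
)
def solve(a_block, b_full):
    # Each device sees its local (N_rows // ndev, N) row block of `a` and the full `b`.
    x, status = jaxmg.potrs_shardmap_ctx(a_block, b_full, T_A)
    return x, status
```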