jaxmg.potri

jaxmg.potri(a, T_A, mesh, in_specs, return_status=False, pad=True)

Compute the inverse of a symmetric matrix using the multi-GPU potri native kernel.

Prepares inputs for the native potri_mg kernel and executes it via jax.ffi.ffi_call under jax.jit and jax.shard_map. Handles per-device padding driven by T_A and symmetrizes the result before returning.

Tip

If the shards of the matrix cannot be tiled evenly with tiles of size T_A (i.e. (N / num_gpus) % T_A != 0), padding must be added to fill the last tile. This requires copying the matrix, which we want to avoid at all costs for large N. Make sure you pick T_A large enough (>= 128) and such that it evenly divides the shard size. In principle, increasing T_A will increase performance at the cost of memory, but depending on N, the performance will saturate.
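
For example, a quick host-side check of which candidate tile widths avoid padding might look as follows (N and the device count here are illustrative assumptions, not values from jaxmg):

```python
# Sketch: check which candidate tile widths evenly cover each shard.
# N and num_gpus are illustrative assumptions.
N = 32768                  # global matrix dimension
num_gpus = 4               # devices in the mesh
shard = N // num_gpus      # per-device row count: 8192

for T_A in (64, 128, 256, 512, 1024):
    # No padding (and no matrix copy) is needed when T_A divides the shard.
    fits = shard % T_A == 0
    print(f"T_A={T_A}: {'fits' if fits else 'requires padding'}")
```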

Parameters:

  • a (Array), required:
    A 2D JAX array of shape (N_rows, N). Must be symmetric and is expected to be sharded across the mesh along the first (row) axis using P(<axis_name>, None).

  • T_A (int), required:
    Tile width used by the native solver. Each local shard length must be a multiple of T_A; if the provided T_A is incompatible with the shard size, the matrix is padded accordingly. For small tile sizes (T_A < 128) the solver can be extremely slow, so ensure that T_A is large enough. In principle, the larger T_A, the faster the solver runs.

  • mesh (Mesh), required:
    JAX device mesh used for jax.shard_map.

  • in_specs (PartitionSpec or tuple/list[PartitionSpec]), required:
    PartitionSpec describing the input sharding (row sharding). May be provided as a single PartitionSpec or as a single-element container holding one.

  • return_status (bool), default False:
    If True, return (A_inv, status), where status is a host-replicated int32 from the native solver. If False, return A_inv only.

  • pad (bool), default True:
    If True (default), apply per-device padding to meet the T_A requirements; if False, the caller must supply already-padded shapes.

Returns:

  jax.Array | tuple[jax.Array, int]:
    The inverted matrix (row-sharded). If return_status=True, the native solver status code is also returned.

Raises:

  • TypeError: If in_specs is not a PartitionSpec or a single-element container.
  • ValueError: If in_specs does not indicate row sharding (P(<axis_name>, None)).
  • AssertionError: If a is not 2D, or if required shapes do not match when pad=False.

Notes
  • The FFI call is executed with donate_argnums=0, enabling zero-copy buffer sharing with the native library.
  • If the native solver fails, the output may contain NaNs and status will be non-zero.
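
A minimal end-to-end sketch (the axis name, matrix size, T_A value, and the positive-definite test matrix are illustrative assumptions, not requirements of jaxmg):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import jaxmg

devices = jax.devices()
mesh = Mesh(devices, axis_names=("x",))   # "x" is an arbitrary axis name

N = 1024 * len(devices)                   # illustrative global dimension
m = jax.random.normal(jax.random.PRNGKey(0), (N, N))
a = m @ m.T + N * jnp.eye(N)              # symmetric test matrix

# Shard the rows of `a` across the mesh, matching P(<axis_name>, None).
a = jax.device_put(a, NamedSharding(mesh, P("x", None)))

a_inv, status = jaxmg.potri(
    a, T_A=256, mesh=mesh, in_specs=P("x", None), return_status=True
)
print(int(status))  # 0 on success; non-zero indicates a native solver failure
```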

jaxmg.potri_shardmap_ctx(a, T_A, pad=True)

Compute the inverse of a symmetric matrix for already-sharded inputs.

This helper is a lower-level variant of potri intended for environments where the caller already manages sharding/device placement (for example, inside a custom shard_map or other placement context). It performs the same per-device padding logic driven by T_A and calls the native potri_mg FFI target directly via jax.ffi.ffi_call.

Warning

On exit, only the upper-triangular part of A_inv is returned. To obtain the full inverse, call jaxmg.potri_symmetrize outside of the shard_map context.

Tip

If the shards of the matrix cannot be tiled evenly with tiles of size T_A (i.e. (N / num_gpus) % T_A != 0), padding must be added to fill the last tile. This requires copying the matrix, which we want to avoid at all costs for large N. Make sure you pick T_A large enough (>= 128) and such that it evenly divides the shard size. In principle, increasing T_A will increase performance at the cost of memory, but depending on N, the performance will saturate.

Parameters:

  • a (Array), required:
    Local, row-sharded slice of the global matrix with shape (shard_size, N), where shard_size is the per-device local row count and N is the global matrix dimension. The matrix should be symmetric.

  • T_A (int), required:
    Tile width used by the native solver. Each local shard length must be a multiple of T_A; if the provided T_A is incompatible with the shard size, the matrix is padded accordingly. For small tile sizes (T_A < 128) the solver can be extremely slow, so ensure that T_A is large enough. In principle, the larger T_A, the faster the solver runs.

  • pad (bool), default True:
    If True (default), apply per-device padding to a to satisfy T_A. If False, the caller must ensure the provided local shape already meets the kernel requirements.

Returns:

  tuple:
    (A_inv, status), where A_inv contains the upper-triangular part of the inverted matrix (in the same local/sharded layout as the input, with padding removed if applied) and status is the int32 status value returned by the native kernel (a shape (1,) device array).

Raises:

  • AssertionError: If a is not 2D.
  • ValueError: If T_A is too large or shape requirements are violated when pad=False.

Notes
  • This function does not create an outer jax.shard_map or apply donate_argnums; it is intended for use when the caller already controls sharding and device placement.
  • Padding is handled with calculate_padding, pad_rows, and unpad_rows.
  • If the native solver fails, the output may contain NaNs and the returned status will be non-zero.
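
As a sketch, a caller-managed jax.shard_map around this helper might look as follows (the axis name, sizes, T_A value, and the exact jaxmg.potri_symmetrize call are illustrative assumptions):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
import jaxmg

mesh = Mesh(jax.devices(), axis_names=("x",))   # "x" is an arbitrary axis name

def local_potri(a_shard):
    # `a_shard` is the local (shard_size, N) slice held by each device.
    return jaxmg.potri_shardmap_ctx(a_shard, T_A=256)

N = 1024 * len(jax.devices())                   # illustrative global dimension
m = jax.random.normal(jax.random.PRNGKey(0), (N, N))
a = m @ m.T + N * jnp.eye(N)                    # symmetric test matrix

a_inv_upper, status = jax.shard_map(
    local_potri,
    mesh=mesh,
    in_specs=P("x", None),
    # Rows of A_inv stay sharded; the per-device (1,) statuses are gathered.
    out_specs=(P("x", None), P("x")),
)(a)

# Only the upper triangle is populated; symmetrize outside the shard_map
# (call signature assumed for illustration).
a_inv = jaxmg.potri_symmetrize(a_inv_upper)
```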