jaxmg.potri

`jaxmg.potri(a, T_A, mesh, in_specs, return_status=False, pad=True)`
Compute the inverse of a symmetric matrix using the multi-GPU `potri` native kernel.

Prepares inputs for the native `potri_mg` kernel and executes it via
`jax.ffi.ffi_call` under `jax.jit` and `jax.shard_map`. Handles
per-device padding driven by `T_A` and symmetrizes the result before
returning.
Tip

If the shards of the matrix cannot be tiled evenly with tiles of size `T_A`
(i.e. `(N / num_gpus) % T_A != 0`), padding has to be added to fill the last
tile. This requires copying the matrix, which should be avoided at all costs
for large `N`. Pick `T_A` large enough (>= 128) and such that it evenly
divides the shards. In principle, increasing `T_A` improves performance at
the cost of memory, but depending on `N` the performance will saturate.
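The check below is a minimal sketch of that rule; the helper name and candidate tile widths are illustrative and not part of `jaxmg`:

```python
# Illustrative helper (not part of jaxmg): choose a tile width that evenly
# divides the per-device shard so the padding copy can be skipped.
def pick_T_A(N, num_gpus, candidates=(1024, 512, 256, 128)):
    shard = N // num_gpus              # rows owned by each device
    for T_A in candidates:             # prefer larger tiles (more performance, more memory)
        if T_A >= 128 and shard % T_A == 0:
            return T_A                 # last tile fits exactly -> no padding copy needed
    return 128                         # fall back; potri will pad the last tile

# Example: N = 16384 on 4 GPUs gives shards of 4096 rows, so T_A = 1024 works.
print(pick_T_A(16384, 4))
```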
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `a` | `Array` | A 2D JAX array of shape `(N, N)` holding the symmetric matrix to invert. | required |
| `T_A` | `int` | Tile width used by the native solver. Each local shard length must be a multiple of `T_A` (padding is applied if `pad=True`). | required |
| `mesh` | `Mesh` | JAX device mesh used for sharding the computation. | required |
| `in_specs` | `PartitionSpec` or `tuple`/`list[PartitionSpec]` | `PartitionSpec` describing the input sharding (row sharding). May be provided as a single spec or as a one-element tuple/list. | required |
| `return_status` | `bool` | If True, also return the native solver status code. | `False` |
| `pad` | `bool` | If True (default), apply per-device padding to meet the tile-size requirement. | `True` |
Returns:

| Type | Description |
|---|---|
| `jax.Array` or `tuple[jax.Array, int]` | The inverted matrix (row-sharded). If `return_status=True`, a `(Array, status)` tuple is returned instead. |
Raises:

| Type | Description |
|---|---|
| `TypeError` | If … |
| `ValueError` | If … |
| `AssertionError` | If … |
Notes

- The FFI call is executed with `donate_argnums=0`, enabling zero-copy buffer sharing with the native library.
- If the native solver fails, the output may contain NaNs and `status` will be non-zero.
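A minimal usage sketch of `potri`, assuming a single 1D device mesh with an axis named `"x"` and a row-sharding spec of `PartitionSpec("x", None)`; the axis name and the test matrix are illustrative only:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

import jaxmg

N, T_A = 4096, 256
mesh = Mesh(jax.devices(), axis_names=("x",))  # assumed: 1D mesh over all local GPUs

# Symmetric positive-definite test matrix (replace with your own data).
key = jax.random.PRNGKey(0)
m = jax.random.normal(key, (N, N), dtype=jnp.float32)
a = m @ m.T + N * jnp.eye(N, dtype=jnp.float32)

# Row sharding: the first axis is split across the mesh axis "x".
in_specs = P("x", None)

a_inv, status = jaxmg.potri(a, T_A, mesh, in_specs, return_status=True)
print(int(status))  # non-zero means the native solver failed (result may contain NaNs)
```

With 4096 rows, any power-of-two GPU count up to 16 keeps the shards divisible by `T_A = 256`, so the padding copy described in the Tip is never triggered.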
jaxmg.potri_shardmap_ctx

`jaxmg.potri_shardmap_ctx(a, T_A, pad=True)`
Compute the inverse of a symmetric matrix for already-sharded inputs.

This helper is a lower-level variant of `potri` intended for
environments where the caller already manages sharding/device placement
(for example inside a custom `shard_map` or other placement context).
It performs the same per-device padding logic driven by `T_A` and calls
the native `potri_mg` FFI target directly via `jax.ffi.ffi_call`.
Warning

On exit, only the upper triangular part of `A_inv` is returned.
To obtain the full inverse, call `jaxmg.potri_symmetrize` outside of
the `shard_map` context.
Tip

If the shards of the matrix cannot be tiled evenly with tiles of size `T_A`
(i.e. `(N / num_gpus) % T_A != 0`), padding has to be added to fill the last
tile. This requires copying the matrix, which should be avoided at all costs
for large `N`. Pick `T_A` large enough (>= 128) and such that it evenly
divides the shards. In principle, increasing `T_A` improves performance at
the cost of memory, but depending on `N` the performance will saturate.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `a` | `Array` | Local, row-sharded slice of the global matrix (shape `(N / num_gpus, N)` for a 1D row sharding). | required |
| `T_A` | `int` | Tile width used by the native solver. Each local shard length must be a multiple of `T_A` (padding is applied if `pad=True`). | required |
| `pad` | `bool` | If True (default), apply per-device padding to meet the tile-size requirement. | `True` |
Returns:

| Name | Type | Description |
|---|---|---|
| `tuple` | `jax.Array` or `tuple[jax.Array, int]` | The row-sharded upper triangular part of the inverse; see the Warning and Notes regarding symmetrization and the solver status. |
Raises:

| Type | Description |
|---|---|
| `AssertionError` | If … |
| `ValueError` | If … |
Notes

- This function does not create an outer `jax.shard_map` or apply `donate_argnums`; it is intended for use when the caller already controls sharding and device placement.
- Padding is handled with `calculate_padding`, `pad_rows`, and `unpad_rows`.
- If the native solver fails, the output may contain NaNs and the returned `status` will be non-zero.
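For comparison with `potri`, here is a hedged sketch of driving `potri_shardmap_ctx` from a caller-managed `jax.shard_map`. It assumes the same 1D mesh with axis `"x"` as above, that the function returns an `(upper-triangular result, status)` pair as the Returns table suggests, and that `jaxmg.potri_symmetrize` accepts the row-sharded result directly; verify these details against your installed version:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

import jaxmg

N, T_A = 4096, 256
mesh = Mesh(jax.devices(), axis_names=("x",))  # assumed: 1D mesh, axis "x"

key = jax.random.PRNGKey(0)
m = jax.random.normal(key, (N, N), dtype=jnp.float32)
a = m @ m.T + N * jnp.eye(N, dtype=jnp.float32)

@partial(jax.shard_map, mesh=mesh, in_specs=P("x", None),
         out_specs=(P("x", None), P()))
def local_inverse(a_local):
    # a_local is this device's row-sharded slice of the global matrix.
    a_inv_local, status = jaxmg.potri_shardmap_ctx(a_local, T_A)
    # Sum the per-device status over the mesh axis so it is replicated,
    # matching the P() out_spec; zero still means every device succeeded
    # (assuming non-negative status codes).
    return a_inv_local, jax.lax.psum(status, axis_name="x")

a_inv_upper, status = local_inverse(a)
# Only the upper triangle is populated; symmetrize outside the shard_map context.
a_inv = jaxmg.potri_symmetrize(a_inv_upper)
```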