jaxmg.potrs

jaxmg.potrs(a, b, T_A, mesh, in_specs, return_status=False, pad=True)
Solve the linear system A x = B using the multi-GPU potrs native kernel.

Prepares inputs for the native `potrs_mg` kernel and executes it via
`jax.ffi.ffi_call` under `jax.jit` and `jax.shard_map`. Handles
per-device padding driven by `T_A` and returns the solution (and
optionally a host-side solver status).
Tip

If the matrix shards cannot be covered evenly by tiles of size `T_A`
(`(N / num_gpus) % T_A != 0`), padding must be added to fill the last tile.
This requires copying the matrix, which should be avoided at all costs for
large `N`. Pick `T_A` large enough (>= 128) and such that it evenly
divides the shard length. In principle, increasing `T_A` improves
performance at the cost of memory, but depending on `N` the performance
will saturate.
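For example (hypothetical sizes, just to illustrate the divisibility rule):

```python
# N = 40960 rows sharded over 4 GPUs -> 10240 rows per shard.
# T_A = 512 divides 10240 evenly, so no padded copy is needed;
# T_A = 768 would leave a remainder (10240 % 768 == 256) and force a copy.
N, num_gpus, T_A = 40960, 4, 512
assert (N // num_gpus) % T_A == 0
```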
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `a` | `Array` | 2D, symmetric matrix representing the coefficient matrix. Expected to be sharded across the mesh along the first (row) axis. | required |
| `b` | `Array` | 2D right-hand side. Expected to be replicated across devices. | required |
| `T_A` | `int` | Tile width used by the native solver. Each local shard length must be a multiple of `T_A`. | required |
| `mesh` | `Mesh` | JAX `Mesh` object used for the `shard_map` call. | required |
| `in_specs` | `tuple \| list[PartitionSpec]` | The sharding specifications for `a` and `b`. | required |
| `return_status` | `bool` | If `True`, also return the solver status. | `False` |
| `pad` | `bool` | If `True` (default), apply per-device padding to `a`. | `True` |
Returns:

| Type | Description |
|---|---|
| `jax.Array \| tuple[jax.Array, int]` | Array or (Array, int): the solution `x`, or `(x, status)` if `return_status=True`. |
Raises:

| Type | Description |
|---|---|
| `AssertionError` | If input arrays are not 2D or their shapes are incompatible. |
| `ValueError` | If the local shard length is not compatible with `T_A`. |
Notes

- The FFI call may donate the `a` buffer (`donate_argnums=0`) for zero-copy interaction with the native library.
- If the native solver fails, the returned solution may contain NaNs and `status` will be non-zero.
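A minimal usage sketch follows. It assumes a one-axis mesh named `"x"`, that `in_specs` takes one `PartitionSpec` per input (`a` sharded along rows, `b` replicated), and that `status == 0` indicates success; adapt these assumptions to your setup.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import jaxmg

N, T_A = 4096, 256                      # (N / num_gpus) % T_A == 0 for 1, 2, 4, or 8 GPUs
mesh = Mesh(np.array(jax.devices()), ("x",))

# Symmetric positive-definite A and a single right-hand side.
m = jax.random.normal(jax.random.PRNGKey(0), (N, N))
a = m @ m.T + N * jnp.eye(N)
b = jnp.ones((N, 1))

# Shard A along its rows, replicate B (assumed layout; see Parameters above).
a = jax.device_put(a, NamedSharding(mesh, P("x", None)))
b = jax.device_put(b, NamedSharding(mesh, P(None, None)))

x, status = jaxmg.potrs(a, b, T_A, mesh, (P("x", None), P(None, None)),
                        return_status=True)
assert status == 0  # non-zero status signals a native solver failure (see Notes)
```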
jaxmg.potrs_shardmap_ctx

jaxmg.potrs_shardmap_ctx(a, b, T_A, pad=True)
Solve A x = B by invoking the native multi-GPU potrs kernel without a shard_map wrapper.

This helper is a lightweight, lower-level variant of `jaxmg.potrs` intended
for contexts where the input `a` is already laid out and sharded at the
application level (for example, when running inside a custom
shard_map/pjit-managed context). It performs the same padding logic
driven by `T_A` and directly calls the native `potrs_mg` FFI targets
via `jax.ffi.ffi_call` instead of constructing an additional shard_map
wrapper.
Tip

If the matrix shards cannot be covered evenly by tiles of size `T_A`
(`(N / num_gpus) % T_A != 0`), padding must be added to fill the last tile.
This requires copying the matrix, which should be avoided at all costs for
large `N`. Pick `T_A` large enough (>= 128) and such that it evenly
divides the shard length. In principle, increasing `T_A` improves
performance at the cost of memory, but depending on `N` the performance
will saturate.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `a` | `Array` | 2D coefficient matrix. | required |
| `b` | `Array` | 2D right-hand side. Its first dimension must equal the number of columns of `a`. | required |
| `T_A` | `int` | Tile width used by the native solver. Each local shard length must be a multiple of `T_A`. | required |
| `pad` | `bool` | If `True` (default), apply per-device padding to `a`. | `True` |
Returns:

| Name | Type | Description |
|---|---|---|
| `tuple` | `tuple[jax.Array, jax.Array]` | The solution and the solver status. |
Raises:

| Type | Description |
|---|---|
| `AssertionError` | If input arrays are not 2D or their shapes are incompatible. |
Notes

- This function does not perform sharding via `jax.shard_map` and therefore must be called only in a shard_map context.
- Because it does not use `donate_argnums`, the input buffers are not donated to the FFI call (no zero-copy donation semantics).
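A rough sketch of how this might be wired into a user-managed shard_map is shown below. The mesh axis name, the `in_specs`/`out_specs` layouts, and the output layout are assumptions, not part of the documented API; only the `potrs_shardmap_ctx(a, b, T_A)` call itself follows the signature above.

```python
import jax
import numpy as np
from functools import partial
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
import jaxmg

mesh = Mesh(np.array(jax.devices()), ("x",))
T_A = 256

# Assumed layout: `a` sharded along rows, `b` replicated; the out_specs below
# are placeholders and depend on what the native kernel actually produces.
@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P("x", None), P(None, None)),
         out_specs=(P("x", None), P()),
         check_rep=False)
def solve(a_shard, b_repl):
    # Each device sees only its local row block of `a`; potrs_shardmap_ctx
    # applies the T_A padding itself and calls the potrs_mg FFI target directly.
    return jaxmg.potrs_shardmap_ctx(a_shard, b_repl, T_A)
```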