jaxmg.syevd

jaxmg.syevd(a, T_A, mesh, in_specs, return_eigenvectors=True, return_status=False, pad=True)
Compute eigenvalues (and optionally eigenvectors) of a symmetric matrix via the multi-GPU syevd kernel.
Prepares the input and executes the appropriate native cuSolverMg kernel
(`syevd_mg` when eigenvectors are requested, `syevd_no_V_mg` otherwise)
via `jax.ffi.ffi_call` under `jax.jit` and `jax.shard_map`. Handles
per-device padding driven by `T_A` and returns the eigenvalues and, optionally,
the eigenvectors and a host-side status.
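A minimal end-to-end sketch (the axis name `"x"`, the sizes, and the random symmetric input are illustrative assumptions, not part of the API):

```python
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

import jaxmg

ndev = jax.device_count()
N, T_A = 4096, 256                     # assumes (N // ndev) % T_A == 0, so no padding copy
mesh = jax.make_mesh((ndev,), ("x",))  # 1D mesh over all local GPUs

# Build a symmetric matrix and shard it by rows across the mesh.
key = jax.random.key(0)
a = jax.random.normal(key, (N, N))
a = (a + a.T) / 2
a = jax.device_put(a, NamedSharding(mesh, P("x")))

# Eigenvalues plus row-sharded eigenvectors.
eigvals, eigvecs = jaxmg.syevd(a, T_A=T_A, mesh=mesh, in_specs=P("x"))
```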
Tip

If the shards of the matrix cannot be tiled exactly with tiles of size `T_A`
(that is, `(N / num_gpus) % T_A != 0`), padding must be added to fill the last
tile. This requires copying the matrix, which should be avoided at all costs for
large `N`. Pick `T_A` large enough (`>= 128`) and such that it evenly divides
the shards. In principle, increasing `T_A` increases performance at the cost of
memory, but depending on `N` the performance will saturate.
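The divisibility condition can be checked up front; a small helper (hypothetical, not part of jaxmg) makes the arithmetic concrete:

```python
def needs_padding(N: int, num_gpus: int, T_A: int) -> bool:
    """True when the (N // num_gpus)-row shard cannot be tiled exactly by T_A,
    in which case syevd must copy the shard to pad out the last tile."""
    return (N // num_gpus) % T_A != 0

# 4096 rows over 4 GPUs gives 1024-row shards:
assert not needs_padding(4096, 4, 256)  # 256 divides 1024 -> no copy
assert needs_padding(4096, 4, 384)      # 384 does not     -> padding copy
```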
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `a` | `Array` | A 2D JAX array of shape `(N, N)` holding the symmetric input matrix, row-sharded across the mesh. | required |
| `T_A` | `int` | Tile width used by the native solver. Each local shard length must be a multiple of `T_A` (met via padding when `pad=True`). | required |
| `mesh` | `Mesh` | JAX device mesh used for the `shard_map` call. | required |
| `in_specs` | `PartitionSpec` or `tuple`/`list[PartitionSpec]` | `PartitionSpec` describing the input sharding (row sharding). May be provided as a single spec or wrapped in a tuple/list. | required |
| `return_eigenvectors` | `bool` | If True (default), compute and return eigenvectors alongside eigenvalues. Eigenvectors are returned row-sharded to the same layout as the input and will be unpadded if padding was applied. | `True` |
| `return_status` | `bool` | If True, append a host-replicated int32 solver status to the return values. | `False` |
| `pad` | `bool` | If True (default), apply per-device padding so each local shard meets the `T_A` tile requirement. | `True` |
Returns:

| Type | Description |
|---|---|
| `jax.Array` \| `tuple[jax.Array, jax.Array]` \| `tuple[jax.Array, jax.Array, int]` \| `tuple[jax.Array, int]` | Depending on `return_eigenvectors` and `return_status`: eigenvalues only; `(eigenvalues, eigenvectors)`; `(eigenvalues, eigenvectors, status)`; or `(eigenvalues, status)`. |
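As a sketch of the resulting unpacking patterns (reusing `a`, `mesh`, and `P` from the example above):

```python
# Eigenvalues only.
w = jaxmg.syevd(a, T_A=256, mesh=mesh, in_specs=P("x"), return_eigenvectors=False)

# Eigenvalues, row-sharded eigenvectors, and the host-side solver status.
w, v, status = jaxmg.syevd(a, T_A=256, mesh=mesh, in_specs=P("x"), return_status=True)
if status != 0:
    raise RuntimeError(f"cuSolverMg syevd failed with status {status}")
```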
Raises:

| Type | Description |
|---|---|
| `TypeError` | If an argument has an unexpected type (for example, `in_specs` is not a `PartitionSpec` or a tuple/list of `PartitionSpec`s). |
| `ValueError` | If the input shape or sharding is incompatible with `mesh` and `in_specs`. |
| `AssertionError` | If `pad=False` and the local shard length is not a multiple of `T_A`. |
Notes
- Eigenvectors (when requested) are returned in the same row sharding as the input.
- The FFI call can donate the input buffer (`donate_argnums=0`) to enable zero-copy interaction with the native library.
- If the native solver fails, the outputs may contain NaNs and the status (when requested) will be non-zero.
jaxmg.syevd_shardmap_ctx(a, T_A, return_eigenvectors=True, pad=True)
Compute eigenvalues (and optionally eigenvectors) for row-sharded inputs without shard_map wiring.
This helper is a lightweight, lower-level variant of :func:`syevd` intended
for contexts where the input `a` is already laid out and sharded at the
application level (for example, when running inside a custom
shard_map/pjit-managed context). It performs the same padding logic
driven by `T_A` and directly calls the native `syevd_mg` /
`syevd_no_V_mg` FFI targets via `jax.ffi.ffi_call` instead of
constructing an additional `shard_map` wrapper.
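A hedged sketch of calling the helper from a caller-managed shard_map (mesh and `a` as in the syevd example above; the replicated-eigenvalue `out_specs` and `check_rep=False` are assumptions about the kernel's output layout, not documented guarantees):

```python
from functools import partial

import jax
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

import jaxmg

mesh = jax.make_mesh((jax.device_count(),), ("x",))

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P("x"),
         out_specs=(P(), P("x")), check_rep=False)
def eigh_rows(a_local):
    # a_local is this device's (N // num_gpus, N) row slice of the global matrix.
    return jaxmg.syevd_shardmap_ctx(a_local, T_A=256)

eigvals, eigvecs = eigh_rows(a)
```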
Tip

If the shards of the matrix cannot be tiled exactly with tiles of size `T_A`
(that is, `(N / num_gpus) % T_A != 0`), padding must be added to fill the last
tile. This requires copying the matrix, which should be avoided at all costs for
large `N`. Pick `T_A` large enough (`>= 128`) and such that it evenly divides
the shards. In principle, increasing `T_A` increases performance at the cost of
memory, but depending on `N` the performance will saturate.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `a` | `Array` | 2D JAX array representing the local, row-sharded slice of the global matrix. Shape should be `(N // num_gpus, N)`. | required |
| `T_A` | `int` | Tile width used by the native solver. Each local shard length must be a multiple of `T_A` (met via padding when `pad=True`). | required |
| `return_eigenvectors` | `bool` | If True (default), compute and return eigenvectors in addition to eigenvalues. When True the returned eigenvector array has the same local/sharded shape as the input (and will be unpadded if padding was applied). | `True` |
| `pad` | `bool` | If True (default), apply per-device padding so the local shard meets the `T_A` tile requirement. | `True` |
Returns:

| Type | Description |
|---|---|
| `jax.Array` \| `tuple[jax.Array, jax.Array]` \| `tuple[jax.Array, jax.Array, int]` \| `tuple[jax.Array, int]` | One of the following, depending on `return_eigenvectors` and whether a status is returned: eigenvalues; `(eigenvalues, eigenvectors)`; `(eigenvalues, eigenvectors, status)`; or `(eigenvalues, status)`. |
Raises:

| Type | Description |
|---|---|
| `AssertionError` | If `pad=False` and the local shard length is not a multiple of `T_A`. |
| `ValueError` | If `a` is not a valid 2D array or its shape is incompatible with `T_A`. |
Notes
- This function does not create a `jax.shard_map` wrapper and does not set `donate_argnums`; it is intended for use when the caller already controls sharding/device placement.
- Padding is handled via :func:`calculate_padding`, :func:`pad_rows`, and :func:`unpad_rows` (the latter two are used as local callables rather than shard_map-wrapped functions).
- If the native solver fails, the outputs may contain NaNs and the returned `status` (when present) will be non-zero.