
jaxmg.syevd

jaxmg.syevd(a, T_A, mesh, in_specs, return_eigenvectors=True, return_status=False, pad=True)

Compute eigenvalues (and optionally eigenvectors) of a symmetric matrix via the multi-GPU syevd kernel.

Prepares the input and executes the appropriate native cuSolverMg kernel (syevd_mg when eigenvectors are requested or syevd_no_V_mg when not) via jax.ffi.ffi_call under jax.jit and jax.shard_map. Handles per-device padding driven by T_A and returns eigenvalues and, optionally, eigenvectors and a host-side status.

Tip

If the per-device shards cannot be tiled evenly with tiles of size T_A (i.e. (N / num_gpus) % T_A != 0), padding must be added to fill the last tile. This requires copying the matrix, which should be avoided at all costs for large N. Choose T_A large enough (>= 128) and such that it evenly divides the shard size. In principle, increasing T_A improves performance at the cost of memory, though depending on N the performance will saturate.
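The tile-size rule above can be encoded in a small helper. The sketch below is hypothetical (it is not part of jaxmg) and assumes N is evenly divisible by the number of GPUs; it picks the largest T_A <= 1024 that evenly divides the per-GPU shard, preferring values >= 128, and falls back to the largest divisor (accepting the slow-tile caveat) when no divisor >= 128 exists:

```python
def choose_tile_size(n, num_gpus, max_t=1024, min_t=128):
    """Pick a tile width T_A that evenly divides the per-GPU shard size.

    Hypothetical helper, not part of jaxmg: prefers the largest divisor
    of the shard size that satisfies min_t <= T_A <= max_t.
    """
    if n % num_gpus != 0:
        raise ValueError("N must be divisible by the number of GPUs")
    shard = n // num_gpus
    # First pass: largest divisor within [min_t, max_t].
    for t in range(min(max_t, shard), 0, -1):
        if shard % t == 0 and t >= min_t:
            return t
    # Fallback: largest divisor <= max_t, accepting the small-tile
    # performance caveat noted in the tip above.
    for t in range(min(max_t, shard), 0, -1):
        if shard % t == 0:
            return t
```

For example, with N = 4096 on 4 GPUs the shard size is 1024 and the helper returns 1024; with N = 400 on 4 GPUs no divisor of the shard size reaches 128, so it falls back to 100.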

Parameters:

  • a (Array, required) — A 2D JAX array of shape (N_rows, N). Must be symmetric and is expected to be sharded across the mesh along the first (row) axis using P(<axis_name>, None).
  • T_A (int, required) — Tile width used by the native solver. Each local shard length must be a multiple of T_A; if the provided T_A is incompatible with the shard size, the matrix is padded accordingly. For small tile sizes (T_A < 128) the solver can be extremely slow, so ensure T_A is large enough. The cuSolverMg implementation enforces an upper bound of T_A <= 1024.
  • mesh (Mesh, required) — JAX device mesh used for jax.shard_map.
  • in_specs (PartitionSpec or tuple/list[PartitionSpec], required) — PartitionSpec describing the input sharding (row sharding). May be provided as a single PartitionSpec or a single-element container holding one.
  • return_eigenvectors (bool, default True) — If True, compute and return eigenvectors alongside eigenvalues. Eigenvectors are returned row-sharded in the same layout as the input and are unpadded if padding was applied.
  • return_status (bool, default False) — If True, append a host-replicated int32 solver status to the return values.
  • pad (bool, default True) — If True, apply per-device padding to meet T_A requirements; if False, the caller must supply already-correct shapes.

Returns:

jax.Array | tuple[jax.Array, jax.Array] | tuple[jax.Array, jax.Array, int] | tuple[jax.Array, int]

Depending on return_eigenvectors and return_status, one of:

  • eigenvalues (Array of shape (N,))
  • (eigenvalues, eigenvectors)
  • (eigenvalues, status)
  • (eigenvalues, eigenvectors, status)

Raises:

  • TypeError — If in_specs is not a PartitionSpec or a single-element container.
  • ValueError — If in_specs does not indicate row sharding (P(<axis_name>, None)) or if T_A exceeds implementation limits.
  • AssertionError — If a is not 2D or if shape requirements are violated when pad=False.

Notes
  • Eigenvectors (when requested) are returned in the same row sharding as the input.
  • The FFI call can donate the input buffer (donate_argnums=0) to enable zero-copy interaction with the native library.
  • If the native solver fails the outputs may contain NaNs and the status (when requested) will be non-zero.
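As an illustration of the documented calling convention, the sketch below prepares a symmetric, row-shardable input and the mesh/in_specs arguments. The jaxmg call itself is shown in a comment because it requires a multi-GPU CUDA environment; as a runnable stand-in with the same eigenvalue/eigenvector contract, the example uses jnp.linalg.eigh on a single device:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

# Build a small symmetric test matrix (float32).
N = 64
key = jax.random.PRNGKey(0)
m = jax.random.normal(key, (N, N))
a = (m + m.T) / 2.0  # symmetrize

# Row-sharded layout expected by syevd: P("x", None) over a 1-D mesh.
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
in_specs = P("x", None)

# Multi-GPU call (requires jaxmg and CUDA devices; shown for illustration):
#   from jaxmg import syevd
#   eigvals, eigvecs, status = syevd(a, 128, mesh, in_specs, return_status=True)

# Single-device stand-in with the same eigvals/eigvecs contract:
eigvals, eigvecs = jnp.linalg.eigh(a)
```

Eigenvalues come back as a length-N vector in ascending order, and the eigenvector columns reconstruct the input via `eigvecs @ diag(eigvals) @ eigvecs.T`.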

jaxmg.syevd_shardmap_ctx(a, T_A, return_eigenvectors=True, pad=True)

Compute eigenvalues (and optionally eigenvectors) for row-sharded inputs without shard_map wiring.

This helper is a lightweight, lower-level variant of syevd intended for contexts where the input a is already laid out and sharded at the application level (for example when running inside a custom shard_map/pjit-managed context). It performs the same padding logic driven by T_A and directly calls the native syevd_mg / syevd_no_V_mg FFI targets via jax.ffi.ffi_call instead of constructing an additional shard_map wrapper.

Tip

If the per-device shards cannot be tiled evenly with tiles of size T_A (i.e. (N / num_gpus) % T_A != 0), padding must be added to fill the last tile. This requires copying the matrix, which should be avoided at all costs for large N. Choose T_A large enough (>= 128) and such that it evenly divides the shard size. In principle, increasing T_A improves performance at the cost of memory, though depending on N the performance will saturate.
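The amount of padding implied by the rule above is easy to compute up front. The helper below is a hypothetical sketch (it mirrors the stated rule, not jaxmg's internal calculate_padding) and uses ceiling division so that uneven row splits are also covered:

```python
def padding_rows(n, num_gpus, t_a):
    """Rows of padding each device must add so its local shard becomes a
    multiple of t_a. Hypothetical helper mirroring the tiling rule above."""
    shard = -(-n // num_gpus)  # ceil division: per-device row count
    return (-shard) % t_a      # 0 when the shard is already a multiple
```

For example, N = 1000 on 4 GPUs with T_A = 128 gives shards of 250 rows, so 6 rows of padding per device are needed; N = 1024 on 4 GPUs needs none.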

Parameters:

  • a (Array, required) — 2D JAX array representing the local, row-sharded slice of the global matrix. Shape should be (shard_size, N), where shard_size is the per-device (local) row count and N is the global matrix dimension.
  • T_A (int, required) — Tile width used by the native solver. Each local shard length must be a multiple of T_A; if the provided T_A is incompatible with the shard size, the matrix is padded accordingly. For small tile sizes (T_A < 128) the solver can be extremely slow, so ensure T_A is large enough. The cuSolverMg implementation enforces an upper bound of T_A <= 1024.
  • return_eigenvectors (bool, default True) — If True, compute and return eigenvectors in addition to eigenvalues. When True, the returned eigenvector array has the same local/sharded shape as the input (and is unpadded if padding was applied).
  • pad (bool, default True) — If True, apply per-device padding to a to satisfy T_A. If False, the caller must ensure the provided local shape already meets kernel requirements.

Returns:

jax.Array | tuple[jax.Array, jax.Array] | tuple[jax.Array, jax.Array, int] | tuple[jax.Array, int]

One of the following, depending on return_eigenvectors and whether the caller requests status:

  • eigenvalues (Array, shape (N,))
  • (eigenvalues, eigenvectors)
  • (eigenvalues, status)
  • (eigenvalues, eigenvectors, status)

Raises:

  • AssertionError — If a is not a 2D array.
  • ValueError — If T_A exceeds implementation limits.

Notes
  • This function does not create a jax.shard_map wrapper and does not set donate_argnums; it is intended for use when the caller already controls sharding/device placement.
  • Padding is handled via calculate_padding, pad_rows, and unpad_rows (the latter two are used as local callables rather than shard_map-wrapped functions).
  • If the native solver fails the outputs may contain NaNs and the returned status (when present) will be non-zero.
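To illustrate where this helper fits, the sketch below wires up a caller-managed shard_map over a 1-D mesh. The jaxmg call is shown only in a comment (it needs the native library and CUDA devices); as a stand-in so the sketch runs anywhere, each device applies jnp.linalg.eigh to its local slice. The out_specs shown are illustrative assumptions for this stand-in, not jaxmg's documented layout:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

def local_solve(a_local):
    # In a jaxmg setting, this is where syevd_shardmap_ctx would run on
    # the local, row-sharded slice (hypothetical call):
    #   eigvals, eigvecs = syevd_shardmap_ctx(a_local, T_A=128)
    # Stand-in so the sketch runs anywhere: dense eigh of the local block.
    eigvals, eigvecs = jnp.linalg.eigh(a_local)
    return eigvals, eigvecs

solver = shard_map(local_solve, mesh=mesh,
                   in_specs=P("x", None),
                   out_specs=(P("x"), P("x", None)))

# Symmetric input whose rows split evenly across the mesh.
N = 8 * jax.device_count()
key = jax.random.PRNGKey(0)
m = jax.random.normal(key, (N, N))
a = (m + m.T) / 2.0
eigvals, eigvecs = solver(a)
```

The point of the sketch is the wiring: the caller owns the mesh and in_specs, and syevd_shardmap_ctx only ever sees the local slice, matching the "no shard_map wrapper, no donate_argnums" notes above.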