JAXMg

JAXMg provides a C++ interface between JAX and cuSolverMg, NVIDIA’s multi-GPU dense linear algebra library. It exposes a jittable (`jax.jit`-compatible) API for the following routines:

  • cusolverMgPotrs: Solves the system of linear equations \(Ax=b\), where \(A\) is an \(N\times N\) symmetric (Hermitian) positive-definite matrix, via a Cholesky decomposition.
  • cusolverMgPotri: Computes the inverse of an \(N\times N\) symmetric (Hermitian) positive-definite matrix via a Cholesky decomposition.
  • cusolverMgSyevd: Computes the eigenvalues and eigenvectors of an \(N\times N\) symmetric (Hermitian) matrix.
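To make the three operations concrete, here is a minimal single-device sketch of what each routine computes, written with standard JAX functions (`jax.scipy.linalg.cho_factor`, `jax.scipy.linalg.cho_solve`, `jnp.linalg.eigh`). This is not the JAXMg API: the function names `potrs_reference`, `potri_reference`, and `syevd_reference` are hypothetical, and nothing here distributes work across GPUs. It is only a reference one might use to sanity-check results, wrapped in `jax.jit` to mirror the jittable interface.

```python
# Single-device reference for the three operations (NOT the JAXMg API).
# Function names below are hypothetical, chosen for illustration only.
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

jax.config.update("jax_enable_x64", True)  # use float64 for tight comparisons


@jax.jit
def potrs_reference(a, b):
    """Solve A x = b for symmetric positive-definite A via Cholesky."""
    c, lower = cho_factor(a, lower=True)  # A = L L^T
    return cho_solve((c, lower), b)       # two triangular solves


@jax.jit
def potri_reference(a):
    """Invert symmetric positive-definite A by solving A X = I with its Cholesky factor."""
    n = a.shape[0]  # static under jit
    c, lower = cho_factor(a, lower=True)
    return cho_solve((c, lower), jnp.eye(n, dtype=a.dtype))


@jax.jit
def syevd_reference(a):
    """Eigenvalues (ascending) and eigenvectors (columns) of symmetric A."""
    return jnp.linalg.eigh(a)


# Build a small symmetric positive-definite test matrix: M M^T + 4 I.
key = jax.random.PRNGKey(0)
m = jax.random.normal(key, (4, 4))
a = m @ m.T + 4.0 * jnp.eye(4)
b = jnp.arange(4.0)

x = potrs_reference(a, b)
a_inv = potri_reference(a)
w, v = syevd_reference(a)
```

Solving \(AX=I\) with the Cholesky factor is a simple way to express the inverse; a dedicated routine such as cusolverMgPotri may compute it differently internally.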

For more details, see the API.

The provided binaries are compiled with:

| Component | Version     |
|-----------|-------------|
| GCC       | 11.5.0      |
| CUDA      | 12.8.0      |
| cuDNN     | 9.2.0.82-12 |

Warning

JAXMg requires JAX>=0.6.0, because those releases ship with the CUDA 12.x binaries this package relies on. A local CUDA installation is not required.
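A quick way to check an environment against this requirement is sketched below. The helpers `parse_version` and `meets_minimum` are hypothetical and not part of JAXMg; the script only reads `jax.__version__` and lists the GPU devices JAX can see, which is also a useful check before running a multi-GPU solver.

```python
# Hypothetical environment check (not part of JAXMg): JAX version and visible GPUs.
import re

import jax


def parse_version(version):
    """Turn '0.6.2' or '0.6.2.dev20250101' into (0, 6, 2); missing parts count as 0."""
    nums = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)  # keep only the leading digits
        if match is None:
            break
        nums.append(int(match.group()))
    return tuple(nums + [0] * (3 - len(nums)))


def meets_minimum(version, minimum=(0, 6, 0)):
    """True if `version` is at least `minimum` (compared as integer tuples)."""
    return parse_version(version) >= minimum


if __name__ == "__main__":
    print("JAX", jax.__version__, "meets >=0.6.0:", meets_minimum(jax.__version__))
    gpus = [d for d in jax.devices() if d.platform == "gpu"]
    print("visible GPUs:", len(gpus))
```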