JAXMg
JAXMg provides a C++ interface between JAX and cuSolverMg, NVIDIA’s multi-GPU linear solver. We provide a jittable API for the following routines:
- cusolverMgPotrs: Solves the system of linear equations \(Ax=b\), where \(A\) is an \(N\times N\) symmetric (Hermitian) positive-definite matrix, via a Cholesky decomposition.
- cusolverMgPotri: Computes the inverse of an \(N\times N\) symmetric (Hermitian) positive-definite matrix via a Cholesky decomposition.
- cusolverMgSyevd: Computes the eigenvalues and eigenvectors of an \(N\times N\) symmetric (Hermitian) matrix.
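To make the semantics of the three routines concrete, here is a single-device reference sketch in plain JAX. It does not call JAXMg or cuSolverMg. It uses `jax.scipy.linalg.cho_factor`/`cho_solve` and `jnp.linalg.eigh`, which compute the same quantities on one device and could serve as a correctness check for multi-GPU results. The helper names (`make_spd`, `solve_spd`, `inv_spd`, `eig_sym`) are hypothetical. Note that the inverse is obtained here by solving \(AX=I\) with the Cholesky factor, a swapped-in technique rather than the triangular-inversion approach potri uses internally.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl


def make_spd(key, n):
    """Random symmetric positive-definite test matrix: M M^T + n I."""
    m = jax.random.normal(key, (n, n), dtype=jnp.float32)
    return m @ m.T + n * jnp.eye(n, dtype=jnp.float32)


@jax.jit
def solve_spd(a, b):
    # Cholesky factor A = L L^T, then two triangular solves (cf. potrf + potrs).
    c_and_lower = jsl.cho_factor(a, lower=True)
    return jsl.cho_solve(c_and_lower, b)


@jax.jit
def inv_spd(a):
    # Inverse from the Cholesky factor by solving A X = I (result analogous to potri).
    c_and_lower = jsl.cho_factor(a, lower=True)
    return jsl.cho_solve(c_and_lower, jnp.eye(a.shape[0], dtype=a.dtype))


@jax.jit
def eig_sym(a):
    # Eigenvalues (ascending) and eigenvectors of a symmetric matrix (cf. syevd).
    return jnp.linalg.eigh(a)


# Usage: build a small SPD system and report the residuals of each result.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = make_spd(key_a, 8)
b = jax.random.normal(key_b, (8, 2), dtype=jnp.float32)

x = solve_spd(a, b)
a_inv = inv_spd(a)
w, v = eig_sym(a)

print("solve residual:", jnp.max(jnp.abs(a @ x - b)))
print("inverse residual:", jnp.max(jnp.abs(a @ a_inv - jnp.eye(8))))
print("eigen residual:", jnp.max(jnp.abs(a @ v - v * w)))
```

Comparing multi-GPU outputs against references like these on small matrices is a cheap sanity check before scaling up.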
For more details, see the API reference.
The provided binaries are compiled with:
| Component | Version |
|---|---|
| GCC | 11.5.0 |
| CUDA | 12.8.0 |
| cuDNN | 9.2.0.82-12 |
Warning
We require JAX>=0.6.0 because it ships with the CUDA 12.x binaries that this package relies on. No local CUDA installation is required.
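A quick way to confirm the environment before using the multi-GPU routines is to check the installed JAX version against the 0.6.0 minimum and count the visible GPUs. This is a minimal sketch; `parse_version`, `jax_meets_minimum`, and `gpu_count` are hypothetical helpers, not part of JAXMg.

```python
import re

import jax


def parse_version(version):
    # Extract (major, minor, patch) from strings like "0.6.0" or "0.6.1.dev20250101".
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(g) for g in match.groups())


def jax_meets_minimum(minimum=(0, 6, 0)):
    # Tuple comparison orders versions component by component.
    return parse_version(jax.__version__) >= minimum


def gpu_count():
    # jax.devices() lists devices of the default backend; CUDA devices report platform "gpu".
    return sum(1 for d in jax.devices() if d.platform == "gpu")


print("JAX", jax.__version__, "meets minimum:", jax_meets_minimum())
print("visible GPUs:", gpu_count())
```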