Installation
Clone the repository and install with:
pip install jaxmg
This also installs a GPU-compatible version of JAX.
To verify the installation (requires at least one GPU), run:
pytest
The test suite contains two kinds of tests:
- SPMD tests: Single Process, Multiple GPU tests.
- MPMD tests: Multiple Processes, Multiple GPU tests.
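The two test categories differ in how many processes share the GPUs. To see which situation the current environment matches, the sketch below uses standard JAX runtime queries: `jax.default_backend()`, `jax.process_count()` and `jax.local_device_count()`. The helpers `describe_setup` and `describe_current_runtime` are hypothetical and not part of jaxmg. In a multi-process run, `jax.distributed.initialize()` must be called first, otherwise `jax.process_count()` reports only one process.

```python
def describe_setup(backend: str, process_count: int, local_gpu_count: int) -> str:
    """Describe a runtime in terms of processes and GPUs (hypothetical helper).

    backend:         value of jax.default_backend(), e.g. "gpu" or "cpu"
    process_count:   value of jax.process_count()
    local_gpu_count: value of jax.local_device_count() (GPUs when backend == "gpu")
    """
    # The test suite needs at least one GPU visible to JAX.
    if backend != "gpu" or local_gpu_count < 1:
        return "no GPU visible to JAX"
    # One process driving all of its local GPUs.
    if process_count == 1:
        return f"single process, {local_gpu_count} GPU(s)"
    # Several cooperating processes, each seeing its own local GPUs.
    return f"{process_count} processes, {local_gpu_count} GPU(s) in this process"


def describe_current_runtime() -> str:
    """Query JAX for the values that describe_setup expects."""
    import jax  # imported lazily so describe_setup itself has no JAX dependency

    return describe_setup(
        jax.default_backend(), jax.process_count(), jax.local_device_count()
    )
```

For example, `print(describe_current_runtime())` in a single Python process on a machine with four visible GPUs should report a single process with four GPUs. Keeping the classification in a pure function means it can be checked without GPUs or JAX installed.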