
Multiple Process Multiple Devices (MPMD)

In a multi-process context it is less straightforward to set up memory sharing between processes, especially when it comes to passing around device pointers, which are only valid within the process and CUDA context that created them.

The solution used here is CUDA's interprocess communication (cudaIPC) API, which lets a process export a handle to a device memory allocation that other processes can then open. In potrs_mp.cu we again go through shared memory, but this time we share the cudaIPC memory handles:

ipcGetHandleAndOffset(array_data_A, 
                      shmAipc[currentDevice], 
                      shmoffsetA[currentDevice]);
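To make the exchange more concrete, here is a host-only C++ sketch of what one per-device shared-memory slot and the export-side bookkeeping could look like. It is not the code in potrs_mp.cu: the IpcSlot struct and exportSlot function are hypothetical. In real CUDA code the opaque 64-byte handle would come from cudaIpcGetMemHandle and the allocation base from the driver call cuMemGetAddressRange; here both are supplied by the caller so the offset arithmetic can be checked without a GPU.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical layout of one per-device slot in the shared-memory segment.
// 64 bytes matches CUDA_IPC_HANDLE_SIZE, the size of cudaIpcMemHandle_t.
struct IpcSlot {
    unsigned char handle[64];  // opaque IPC handle bytes (stubbed in this sketch)
    std::size_t offset;        // byte offset of the JAX buffer inside its allocation
};

// Fill a slot for a buffer `ptr` that lives inside an allocation starting at `base`.
// In CUDA code, `base` would be obtained with cuMemGetAddressRange and `handle`
// with cudaIpcGetMemHandle; both are passed in here so no GPU is needed.
IpcSlot exportSlot(const void* base, const void* ptr,
                   const unsigned char (&handle)[64]) {
    auto b = reinterpret_cast<std::uintptr_t>(base);
    auto p = reinterpret_cast<std::uintptr_t>(ptr);
    assert(p >= b && "buffer must lie inside the exported allocation");

    IpcSlot slot{};
    std::memcpy(slot.handle, handle, sizeof slot.handle);
    slot.offset = static_cast<std::size_t>(p - b);
    return slot;
}
```

The slot is trivially copyable, so it can be written into a shared segment by one process and read back byte-for-byte by the others.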

A significant complication is that JAX's memory allocation is managed by XLA, which sub-allocates buffers from larger device allocations, so a JAX device pointer is really a base pointer plus some offset. cudaIPC only exports the base pointer of the underlying allocation, so we have to pass the offset around manually and add it back to recover the true pointer:

opened_ptrs_A = ipcGetDevicePointers<data_type>(currentDevice, 
                                                nbGpus,
                                                shmAipc, 
                                                shmoffsetA);
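The following host-only C++ sketch illustrates why the offset, rather than the raw pointer, is what has to travel between processes. The CUDA runtime does not guarantee that an importing process sees the allocation at the same virtual address as the exporter, so only "opened base + offset" is meaningful on the receiving side. ImportedPtr and recoverPointer are hypothetical names chosen to mirror the `bases` field used in the snippet above; in real code `openedBase` would be the pointer returned by cudaIpcOpenMemHandle.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical view of a peer buffer after import: the opened base (which must
// be kept so it can be closed later) and the offset-adjusted data pointer.
template <typename T>
struct ImportedPtr {
    void* base;  // what the IPC open call returned; this is what gets closed
    T* data;     // base + offset, the pointer the solver actually uses
};

// Recover the true buffer pointer from an opened base and a byte offset.
template <typename T>
ImportedPtr<T> recoverPointer(void* openedBase, std::size_t offsetBytes) {
    assert(openedBase != nullptr);
    auto* bytes = static_cast<unsigned char*>(openedBase);
    return ImportedPtr<T>{openedBase, reinterpret_cast<T*>(bytes + offsetBytes)};
}
```

In a test, a second host array with the same contents can stand in for the importer's mapping: the data is found at the same offset even though the address differs.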

We gather all the pointers in process 0 and set up the solver in the same way as before. After completion, it is essential to close the opened memory handles to avoid leaking the mappings. Note that it is the opened base pointers that are passed, not the offset-adjusted ones:

ipcCloseDevicePointers(currentDevice,
                       opened_ptrs_A.bases,
                       nbGpus);
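One way to make this cleanup hard to forget is an RAII guard, sketched below in host-only C++. This is an alternative design, not what potrs_mp.cu does: IpcMappingGuard and its close callback are hypothetical. Skipping the local device assumes, as the currentDevice argument to ipcCloseDevicePointers suggests, that each process uses its own buffer directly instead of opening an IPC handle to it.

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical RAII guard that releases imported peer mappings exactly once.
// `closeFn` would wrap cudaIpcCloseMemHandle in real code; it receives the
// opened base pointers, never the offset-adjusted data pointers.
class IpcMappingGuard {
public:
    IpcMappingGuard(std::vector<void*> bases, int localDevice,
                    std::function<void(void*)> closeFn)
        : bases_(std::move(bases)), local_(localDevice), close_(std::move(closeFn)) {
        assert(local_ >= 0 && local_ < static_cast<int>(bases_.size()));
    }

    // Non-copyable: a copy would close the same mappings twice.
    IpcMappingGuard(const IpcMappingGuard&) = delete;
    IpcMappingGuard& operator=(const IpcMappingGuard&) = delete;

    ~IpcMappingGuard() {
        for (int dev = 0; dev < static_cast<int>(bases_.size()); ++dev) {
            // The local buffer was never opened through IPC, so it is skipped.
            if (dev != local_ && bases_[dev] != nullptr) close_(bases_[dev]);
        }
    }

private:
    std::vector<void*> bases_;
    int local_;
    std::function<void(void*)> close_;
};
```

In CUDA code the callback could be `[](void* p) { cudaIpcCloseMemHandle(p); }`, ideally with the returned cudaError_t checked. Because the destructor also runs during stack unwinding, the mappings are released even if the solver call throws.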

Note: If you've made it this far and have experience or thoughts on this, please reach out!