# Source code for pytorch_finufft.functional

"""
Implementations of the corresponding Autograd functions
"""

import functools
import warnings
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch

try:
    import finufft

    FINUFFT_AVAIL = True
except ImportError:
    FINUFFT_AVAIL = False

try:
    import cufinufft

    if cufinufft.__version__.startswith("1."):
        warnings.warn("pytorch-finufft does not support cufinufft v1.x.x")
        # Treat an unsupported version as unavailable. Without this assignment
        # CUFINUFFT_AVAIL would be left unbound and the availability check
        # below would raise NameError instead of falling back to finufft.
        CUFINUFFT_AVAIL = False
    else:
        CUFINUFFT_AVAIL = True
except ImportError:
    CUFINUFFT_AVAIL = False

# At least one backend (CPU finufft or GPU cufinufft) is required.
if not (FINUFFT_AVAIL or CUFINUFFT_AVAIL):
    raise ImportError(
        "No FINUFFT implementation available. "
        "Install either finufft or cufinufft and ensure they are importable."
    )

import pytorch_finufft.checks as checks

# Alias for ``None``: indexing a tensor with it inserts a new length-1 axis,
# mirroring ``numpy.newaxis``.
newaxis = None


def _tensor_to_numpy(tensor: torch.Tensor) -> Any:
    """Convert a CPU tensor into a NumPy array for the finufft CPU backend.

    The tensor is detached from autograd, and any lazy conjugate/negative bits
    are materialized first: ``Tensor.numpy()`` raises on tensors carrying those
    bits, which can happen for gradients that flowed through ``conj``.
    """
    return tensor.detach().resolve_conj().resolve_neg().numpy()


def get_nufft_func(
    dim: int, nufft_type: int, device: torch.device
) -> Callable[..., torch.Tensor]:
    """
    Return a callable running the ``dim``-D, type ``nufft_type`` NUFFT on ``device``.

    On CUDA devices the matching ``cufinufft.nufft{dim}d{nufft_type}`` function is
    returned with ``gpu_device_id`` bound to ``device.index``. On any other device
    the matching ``finufft`` function is wrapped so that tensor positional
    arguments are converted to NumPy arrays and the result is returned as a
    ``torch.Tensor``.

    Raises
    ------
    RuntimeError
        If the backend required for ``device`` failed to import.
    """
    if device.type == "cuda":
        if not CUFINUFFT_AVAIL:
            raise RuntimeError("CUDA device requested but cufinufft failed to import")
        # note: in the future, cufinufft may figure out gpu_device_id on its own
        # see: https://github.com/flatironinstitute/finufft/issues/420
        return functools.partial(
            getattr(cufinufft, f"nufft{dim}d{nufft_type}"), gpu_device_id=device.index
        )

    if not FINUFFT_AVAIL:
        raise RuntimeError("CPU device requested but finufft failed to import")
    # CPU needs extra work to go to/from torch and numpy
    finufft_func = getattr(finufft, f"nufft{dim}d{nufft_type}")

    def f(*args: Any, **kwargs: Any) -> torch.Tensor:
        # Only tensors are converted; ints/tuples (e.g. n_modes) pass through.
        new_args = [
            _tensor_to_numpy(arg) if isinstance(arg, torch.Tensor) else arg
            for arg in args
        ]
        return torch.from_numpy(finufft_func(*new_args, **kwargs))

    return f


def coordinate_ramps(shape, device):
    """
    Build centered integer index ramps for every axis of ``shape``.

    An axis of length ``n`` gets the ramp ``-(n // 2), ..., n - n // 2 - 1``,
    broadcast over the other axes with ``"ij"`` indexing. The ramps are stacked
    along a new axis and a leading length-1 axis is prepended, so the result
    has shape ``(1, len(shape), *shape)``.
    """
    per_axis = []
    for length in shape:
        lowest = -(length // 2)
        per_axis.append(torch.arange(lowest, lowest + length, device=device))

    grids = torch.meshgrid(*per_axis, indexing="ij")
    return torch.stack(grids).unsqueeze(0)


def batch_fftshift(x: torch.Tensor, n_shifted_dims: int) -> torch.Tensor:
    """Apply ``fftshift`` to the trailing ``n_shifted_dims`` axes of ``x`` only.

    Leading (batch) axes are left untouched.
    """
    trailing_axes = tuple(-k for k in range(n_shifted_dims, 0, -1))
    shifted: torch.Tensor = torch.fft.fftshift(x, dim=trailing_axes)
    return shifted


def batch_ifftshift(x: torch.Tensor, n_shifted_dims: int) -> torch.Tensor:
    """Apply ``ifftshift`` to the trailing ``n_shifted_dims`` axes of ``x`` only.

    Leading (batch) axes are left untouched.
    """
    trailing_axes = tuple(-k for k in range(n_shifted_dims, 0, -1))
    unshifted: torch.Tensor = torch.fft.ifftshift(x, dim=trailing_axes)
    return unshifted


class FinufftType1(torch.autograd.Function):
    """
    FINUFFT problem type 1 (nonuniform points to uniform Fourier modes).

    ``forward`` takes ``(points, values, output_shape, finufftkwargs)``.
    Gradients are provided with respect to ``points`` and ``values``; the
    shape and keyword arguments receive ``None``.
    """

    ISIGN_DEFAULT = -1  # note: FINUFFT default is 1
    MODEORD_DEFAULT = 1  # note: FINUFFT default is 0

    @staticmethod
    def setup_context(
        ctx: Any,
        inputs: Tuple[
            torch.Tensor, torch.Tensor, Any, Optional[Dict[str, Union[int, float]]]
        ],
        output: Any,
    ) -> None:
        # ``finufftkwargs`` may be absent from ``inputs`` when apply() was called
        # without it and defaults were not bound.
        points, values, _, *rest = inputs
        finufftkwargs = rest[0] if rest else None
        ctx.save_for_backward(points, values)

        # copy to avoid mutating caller's dictionary
        finufftkwargs = {} if finufftkwargs is None else dict(finufftkwargs)
        ctx.isign = finufftkwargs.pop("isign", FinufftType1.ISIGN_DEFAULT)
        ctx.mode_ordering = finufftkwargs.pop("modeord", FinufftType1.MODEORD_DEFAULT)
        ctx.finufftkwargs = finufftkwargs

    @staticmethod
    def forward(  # type: ignore
        points: torch.Tensor,
        values: torch.Tensor,
        output_shape: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]],
        finufftkwargs: Optional[Dict[str, Union[int, float]]] = None,
    ) -> torch.Tensor:
        """Return the type 1 transform with shape ``(*batch, *output_shape)``."""
        checks.check_devices(values, points)
        checks.check_dtypes(values, points, "Values")
        checks.check_sizes_t1(values, points)
        points = torch.atleast_2d(points)
        ndim = points.shape[0]
        checks.check_output_shape(ndim, output_shape)
        # An int is accepted for 1D problems; normalize to a tuple so it can be
        # unpacked when reshaping the result below.
        if isinstance(output_shape, int):
            output_shape = (output_shape,)
        output_shape = tuple(output_shape)

        # copy to avoid mutating caller's dictionary
        finufftkwargs = {} if finufftkwargs is None else dict(finufftkwargs)

        finufftkwargs.setdefault("isign", FinufftType1.ISIGN_DEFAULT)
        # pop because cufinufft doesn't support modeord
        modeord = finufftkwargs.pop("modeord", FinufftType1.MODEORD_DEFAULT)

        nufft_func = get_nufft_func(ndim, 1, points.device)

        # all but the last axis of ``values`` are batch axes; flatten them
        batch_dims = values.shape[:-1]
        finufft_out = nufft_func(
            *points,
            values.reshape(-1, values.shape[-1]),
            output_shape,
            **finufftkwargs,
        )
        finufft_out = finufft_out.reshape(*batch_dims, *output_shape)

        if modeord:
            finufft_out = batch_ifftshift(finufft_out, ndim)

        return finufft_out

    @staticmethod
    def vmap(  # type: ignore[override]
        info: Any,
        in_dims: Tuple[Optional[int], ...],
        points: torch.Tensor,
        values: torch.Tensor,
        output_shape: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]],
        finufftkwargs: Optional[Dict[str, Union[int, float]]] = None,
    ) -> Tuple[torch.Tensor, int]:
        batch_points, batch_values, *_ = in_dims
        if batch_values is not None:
            values = values.movedim(batch_values, 0)

        if batch_points is None:
            # forward already treats leading ``values`` axes as batch axes
            output = FinufftType1.apply(points, values, output_shape, finufftkwargs)
            return output, 0

        # Each batch entry has its own points, so run one transform per entry.
        points = points.movedim(batch_points, 0)
        outputs = [
            FinufftType1.apply(
                points[i],
                values[i] if batch_values is not None else values,
                output_shape,
                finufftkwargs,
            )
            for i in range(info.batch_size)
        ]
        return torch.stack(outputs, dim=0), 0

    @staticmethod
    def backward(  # type: ignore[override]
        ctx: Any, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        # The adjoint of a type 1 transform is a type 2 transform with the
        # opposite exponent sign.
        _i_sign = -1 * ctx.isign
        _mode_ordering = ctx.mode_ordering
        finufftkwargs = ctx.finufftkwargs

        points_in, values = ctx.saved_tensors
        points = torch.atleast_2d(points_in)

        device = points.device
        ndim = points.shape[0]

        grad_points = None
        grad_values = None

        if not any(ctx.needs_input_grad[:2]):
            return grad_points, grad_values, None, None

        nufft_func = get_nufft_func(ndim, 2, device)

        if _mode_ordering:
            grad_output = batch_fftshift(grad_output, ndim)

        # group together batched dimensions, if any
        shape = grad_output.shape[-ndim:]
        batch_dims = grad_output.shape[:-ndim]
        batched_grad_output = grad_output.reshape(-1, 1, *shape)
        nbatch = batched_grad_output.shape[0]

        if ctx.needs_input_grad[0]:
            # wrt points: weight the modes by i * sign * k along every axis
            coord_ramps = coordinate_ramps(shape, device)

            batched_values = values.reshape(nbatch, 1, values.shape[-1])

            # (nbatch * ndim) x *shape
            ramped_grad_output = (
                coord_ramps * batched_grad_output * 1j * _i_sign
            ).reshape(-1, *shape)

            # nbatch x ndim x npoints
            backprop_ramp = (
                nufft_func(*points, ramped_grad_output, isign=_i_sign, **finufftkwargs)
                .conj()
                .reshape(nbatch, ndim, -1)
            )

            # sum over batches; match the (possibly 1D) shape of the input points
            grad_points = (
                (backprop_ramp * batched_values)
                .real.sum(dim=0)
                .reshape(points_in.shape)
            )

        if ctx.needs_input_grad[1]:
            # reshape instead of squeeze: squeeze would also drop size-1 mode axes
            grad_values = nufft_func(
                *points,
                batched_grad_output.reshape(nbatch, *shape),
                isign=_i_sign,
                **finufftkwargs,
            ).reshape(*batch_dims, -1)

        # one gradient per forward input: points, values, output_shape, kwargs
        return grad_points, grad_values, None, None


class FinufftType2(torch.autograd.Function):
    """
    FINUFFT problem type 2 (uniform Fourier modes to nonuniform points).

    ``forward`` takes ``(points, targets, finufftkwargs)``. Gradients are
    provided with respect to ``points`` and ``targets``; the keyword
    arguments receive ``None``.
    """

    ISIGN_DEFAULT = -1  # note: FINUFFT default is -1
    MODEORD_DEFAULT = 1  # note: FINUFFT default is 0

    @staticmethod
    def setup_context(
        ctx: Any,
        inputs: Tuple[
            torch.Tensor, torch.Tensor, Optional[Dict[str, Union[int, float]]]
        ],
        output: Any,
    ) -> None:
        # ``finufftkwargs`` may be absent from ``inputs`` when apply() was called
        # without it and defaults were not bound.
        points, targets, *rest = inputs
        finufftkwargs = rest[0] if rest else None

        # copy to avoid mutating caller's dictionary
        finufftkwargs = {} if finufftkwargs is None else dict(finufftkwargs)
        ctx.save_for_backward(points, targets)
        ctx.isign = finufftkwargs.pop("isign", FinufftType2.ISIGN_DEFAULT)
        ctx.mode_ordering = finufftkwargs.pop("modeord", FinufftType2.MODEORD_DEFAULT)
        ctx.finufftkwargs = finufftkwargs

    @staticmethod
    def forward(  # type: ignore
        points: torch.Tensor,
        targets: torch.Tensor,
        finufftkwargs: Optional[Dict[str, Union[int, float]]] = None,
    ) -> torch.Tensor:
        """Return the type 2 transform with shape ``(*batch, npoints)``."""
        checks.check_devices(targets, points)
        checks.check_dtypes(targets, points, "Targets")
        checks.check_sizes_t2(targets, points)

        # copy to avoid mutating caller's dictionary
        finufftkwargs = {} if finufftkwargs is None else dict(finufftkwargs)

        finufftkwargs.setdefault("isign", FinufftType2.ISIGN_DEFAULT)
        # pop because cufinufft doesn't support modeord
        modeord = finufftkwargs.pop("modeord", FinufftType2.MODEORD_DEFAULT)

        points = torch.atleast_2d(points)
        ndim = points.shape[0]
        npoints = points.shape[1]
        if modeord:
            targets = batch_fftshift(targets, ndim)

        nufft_func = get_nufft_func(ndim, 2, points.device)
        # the last ``ndim`` axes of ``targets`` are modes; the rest are batch axes
        batch_dims = targets.shape[:-ndim]
        shape = targets.shape[-ndim:]
        finufft_out = nufft_func(
            *points,
            targets.reshape(-1, *shape),
            **finufftkwargs,
        )
        finufft_out = finufft_out.reshape(*batch_dims, npoints)

        return finufft_out

    @staticmethod
    def vmap(  # type: ignore[override]
        info: Any,
        in_dims: Tuple[Optional[int], ...],
        points: torch.Tensor,
        targets: torch.Tensor,
        finufftkwargs: Optional[Dict[str, Union[int, float]]] = None,
    ) -> Tuple[torch.Tensor, int]:
        batch_points, batch_targets, *_ = in_dims

        if batch_targets is not None:
            targets = targets.movedim(batch_targets, 0)

        if batch_points is None:
            # forward already treats leading ``targets`` axes as batch axes
            return FinufftType2.apply(points, targets, finufftkwargs), 0

        # Each batch entry has its own points, so run one transform per entry.
        # potential opportunity for CUDA streams
        points = points.movedim(batch_points, 0)
        outputs = [
            FinufftType2.apply(
                points[i],
                targets[i] if batch_targets is not None else targets,
                finufftkwargs,
            )
            for i in range(info.batch_size)
        ]
        return torch.stack(outputs, dim=0), 0

    @staticmethod
    def backward(  # type: ignore[override]
        ctx: Any, grad_output: torch.Tensor
    ) -> Tuple[
        Union[torch.Tensor, None],
        Union[torch.Tensor, None],
        None,
    ]:
        _i_sign = ctx.isign
        _mode_ordering = ctx.mode_ordering
        finufftkwargs = ctx.finufftkwargs

        points_in, targets = ctx.saved_tensors
        points = torch.atleast_2d(points_in)
        device = points.device
        ndim = points.shape[0]

        grad_points = None
        grad_targets = None

        if not any(ctx.needs_input_grad[:2]):
            return grad_points, grad_targets, None

        if _mode_ordering:
            # TODO this was also computed in forward
            targets = batch_fftshift(targets, ndim)

        batch_dims = targets.shape[:-ndim]
        shape = targets.shape[-ndim:]
        batched_targets = targets.reshape(-1, 1, *shape)
        nbatch = batched_targets.shape[0]
        npoints = grad_output.shape[-1]
        batched_outputs = grad_output.reshape(nbatch, 1, npoints)

        if ctx.needs_input_grad[0]:
            # wrt. points: weight the modes by i * sign * k along every axis
            nufft_func = get_nufft_func(ndim, 2, device)

            coord_ramps = coordinate_ramps(shape, device)

            # (nbatch * ndim) x *shape
            ramped_targets = (coord_ramps * batched_targets * 1j * _i_sign).reshape(
                -1, *shape
            )

            # nbatch x ndim x npoints. The conj cannot be replaced by flipping
            # isign alone: ramped_targets is complex, and conjugating the sum
            # also conjugates those coefficients, not just the exponential.
            backprop_ramp = (
                nufft_func(*points, ramped_targets, isign=_i_sign, **finufftkwargs)
                .conj()
                .reshape(nbatch, ndim, -1)
            )

            # sum over batches; match the (possibly 1D) shape of the input points
            grad_points = (
                (backprop_ramp * batched_outputs)
                .real.sum(dim=0)
                .reshape(points_in.shape)
            )

        if ctx.needs_input_grad[1]:
            # wrt. targets: the adjoint is a type 1 transform with flipped sign
            nufft_func = get_nufft_func(ndim, 1, device)

            # reshape instead of squeeze: squeeze would also drop a size-1
            # points axis (a single nonuniform point)
            grad_targets = nufft_func(
                *points,
                batched_outputs.reshape(nbatch, npoints),
                shape,
                isign=-_i_sign,
                **finufftkwargs,
            ).reshape(*batch_dims, *shape)

            if _mode_ordering:
                grad_targets = batch_ifftshift(grad_targets, ndim)

        # one gradient per forward input: points, targets, finufftkwargs
        return grad_points, grad_targets, None


def finufft_type1(
    points: torch.Tensor,
    values: torch.Tensor,
    output_shape: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]],
    **finufftkwargs: Union[int, float],
) -> torch.Tensor:
    """
    Evaluates the Type 1 (nonuniform-to-uniform) NUFFT on the inputs.

    This is a wrapper around :func:`finufft.nufft1d1`, :func:`finufft.nufft2d1`,
    and :func:`finufft.nufft3d1` on CPU, and :func:`cufinufft.nufft1d1`,
    :func:`cufinufft.nufft2d1`, and :func:`cufinufft.nufft3d1` on GPU.

    Parameters
    ----------
    points : torch.Tensor
        DxN tensor of locations of the non-uniform points. For 1D problems a
        length-N tensor is also accepted. Points should lie in the range
        ``[-pi, pi]``; values outside will be folded.
    values : torch.Tensor
        Complex-valued tensor of values at the non-uniform points.
        All dimensions except the final dimension are treated as batch
        dimensions. The final dimension must have size ``N``.
    output_shape : int | tuple(int, ...)
        Requested output shape of Fourier modes. Must be a tuple of length D or
        an integer (1D only).
    **finufftkwargs : int | float
        Additional keyword arguments are forwarded to the underlying
        FINUFFT functions. A few notable options are

        - ``eps``: precision requested (default: ``1e-6``)
        - ``modeord``: 0 for FINUFFT default, 1 for Pytorch default (default: ``1``)
        - ``isign``: Sign of the exponent in the Fourier transform (default: ``-1``)

    Returns
    -------
    torch.Tensor
        Tensor with shape ``*[batch], *output_shape`` containing the Fourier
        transform of the values.
    """
    res: torch.Tensor = FinufftType1.apply(points, values, output_shape, finufftkwargs)
    return res


def finufft_type2(
    points: torch.Tensor,
    targets: torch.Tensor,
    **finufftkwargs: Union[int, float],
) -> torch.Tensor:
    """
    Evaluates the Type 2 (uniform-to-nonuniform) NUFFT on the inputs.

    This is a wrapper around :func:`finufft.nufft1d2`, :func:`finufft.nufft2d2`,
    and :func:`finufft.nufft3d2` on CPU, and :func:`cufinufft.nufft1d2`,
    :func:`cufinufft.nufft2d2`, and :func:`cufinufft.nufft3d2` on GPU.

    Parameters
    ----------
    points : torch.Tensor
        DxN tensor of locations of the non-uniform points. For 1D problems a
        length-N tensor is also accepted. Points should lie in the range
        ``[-pi, pi]``; values outside will be folded.
    targets : torch.Tensor
        Complex-valued tensor of Fourier modes to evaluate at the points.
        The final D dimensions must contain the Fourier modes, and any
        preceding dimensions are treated as batch dimensions.
    **finufftkwargs : int | float
        Additional keyword arguments are forwarded to the underlying
        FINUFFT functions. A few notable options are

        - ``eps``: precision requested (default: ``1e-6``)
        - ``modeord``: 0 for FINUFFT default, 1 for Pytorch default (default: ``1``)
        - ``isign``: Sign of the exponent in the Fourier transform (default: ``-1``)

    Returns
    -------
    torch.Tensor
        Tensor with shape ``*[batch], N`` containing the values at the
        non-uniform points.
    """
    res: torch.Tensor = FinufftType2.apply(points, targets, finufftkwargs)
    return res