"""Implementations of the corresponding Autograd functions"""importfunctoolsimportwarningsfromtypingimportAny,Callable,Dict,Optional,Tuple,Unionimporttorchtry:importfinufftFINUFFT_AVAIL=TrueexceptImportError:FINUFFT_AVAIL=Falsetry:importcufinufftifcufinufft.__version__.startswith("1."):warnings.warn("pytorch-finufft does not support cufinufft v1.x.x")else:CUFINUFFT_AVAIL=TrueexceptImportError:CUFINUFFT_AVAIL=Falseifnot(FINUFFT_AVAILorCUFINUFFT_AVAIL):raiseImportError("No FINUFFT implementation available. ""Install either finufft or cufinufft and ensure they are importable.")importpytorch_finufft.checksaschecksnewaxis=Nonedefget_nufft_func(dim:int,nufft_type:int,device:torch.device)->Callable[...,torch.Tensor]:ifdevice.type=="cuda":ifnotCUFINUFFT_AVAIL:raiseRuntimeError("CUDA device requested but cufinufft failed to import")# note: in the future, cufinufft may figure out gpu_device_id on its own# see: https://github.com/flatironinstitute/finufft/issues/420returnfunctools.partial(getattr(cufinufft,f"nufft{dim}d{nufft_type}"),gpu_device_id=device.index)ifnotFINUFFT_AVAIL:raiseRuntimeError("CPU device requested but finufft failed to import")# CPU needs extra work to go to/from torch and numpyfinufft_func=getattr(finufft,f"nufft{dim}d{nufft_type}")deff(*args,**kwargs):new_args=[argforarginargs]foriinrange(len(new_args)):ifisinstance(new_args[i],torch.Tensor):new_args[i]=new_args[i].data.numpy()returntorch.from_numpy(finufft_func(*new_args,**kwargs))returnfdefcoordinate_ramps(shape,device):start_points=-(torch.tensor(shape,device=device)//2)end_points=start_points+torch.tensor(shape,device=device)coord_ramps=torch.stack(torch.meshgrid(*(torch.arange(start,end,device=device)forstart,endinzip(start_points,end_points)),indexing="ij",))returncoord_ramps[newaxis]defbatch_fftshift(x:torch.Tensor,n_shifted_dims:int)->torch.Tensor:"""fftshift only over the final n_shifted_dims dimensions"""out:torch.Tensor=torch.fft.fftshift(x,dim=tuple(range(-n_shifted_dims,0)))returnoutdefbatch_ifftshift(x:torch.Tensor,n_shifted_dims:int)->torch.Tensor:"""ifftshift only over the final n_shifted_dims dimensions"""out:torch.Tensor=torch.fft.ifftshift(x,dim=tuple(range(-n_shifted_dims,0)))returnoutclassFinufftType1(torch.autograd.Function):""" FINUFFT problem type 1 """ISIGN_DEFAULT=-1# note: FINUFFT default is 1MODEORD_DEFAULT=1# note: FINUFFT default is 0@staticmethoddefsetup_context(ctx:Any,inputs:Tuple[torch.Tensor,torch.Tensor,Any,Optional[Dict[str,Union[int,float]]]],output:Any,)->None:points,values,_,finufftkwargs=inputsctx.save_for_backward(points,values)iffinufftkwargsisNone:finufftkwargs={}else:# copy to avoid mutating caller's dictionaryfinufftkwargs=finufftkwargs.copy()ctx.isign=finufftkwargs.pop("isign",FinufftType1.ISIGN_DEFAULT)ctx.mode_ordering=finufftkwargs.pop("modeord",FinufftType1.MODEORD_DEFAULT)ctx.finufftkwargs=finufftkwargs@staticmethoddefforward(# type: ignorepoints:torch.Tensor,values:torch.Tensor,output_shape:Union[int,Tuple[int],Tuple[int,int],Tuple[int,int,int]],finufftkwargs:Optional[Dict[str,Union[int,float]]]=None,)->torch.Tensor:checks.check_devices(values,points)checks.check_dtypes(values,points,"Values")checks.check_sizes_t1(values,points)points=torch.atleast_2d(points)ndim=points.shape[0]checks.check_output_shape(ndim,output_shape)iffinufftkwargsisNone:finufftkwargs=dict()else:# copy to avoid mutating caller's dictionaryfinufftkwargs=finufftkwargs.copy()finufftkwargs.setdefault("isign",FinufftType1.ISIGN_DEFAULT)# pop because cufinufft doesn't support 
        modeord = finufftkwargs.pop("modeord", FinufftType1.MODEORD_DEFAULT)

        nufft_func = get_nufft_func(ndim, 1, points.device)

        batch_dims = values.shape[:-1]
        finufft_out = nufft_func(
            *points,
            values.reshape(-1, values.shape[-1]),
            output_shape,
            **finufftkwargs,
        )
        finufft_out = finufft_out.reshape(*batch_dims, *output_shape)
        if modeord:
            finufft_out = batch_ifftshift(finufft_out, ndim)

        return finufft_out

    @staticmethod
    def vmap(  # type: ignore[override]
        info: Any,
        in_dims: Tuple[Optional[int], ...],
        points: torch.Tensor,
        values: torch.Tensor,
        output_shape: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]],
        finufftkwargs: Optional[Dict[str, Union[int, float]]] = None,
    ) -> Tuple[torch.Tensor, int]:
        batch_points, batch_values, *_ = in_dims

        if batch_values is not None:
            values = values.movedim(batch_values, 0)

        if batch_points is not None:  # need a for-loop here
            points = points.movedim(batch_points, 0)
            if batch_values is not None:
                output = torch.stack(
                    [
                        FinufftType1.apply(
                            points[i],
                            values[i],
                            output_shape,
                            finufftkwargs,
                        )
                        for i in range(info.batch_size)
                    ],
                    dim=0,
                )
            else:
                output = torch.stack(
                    [
                        FinufftType1.apply(
                            points[i],
                            values,
                            output_shape,
                            finufftkwargs,
                        )
                        for i in range(info.batch_size)
                    ],
                    dim=0,
                )
        else:
            output = FinufftType1.apply(points, values, output_shape, finufftkwargs)

        return output, 0

    @staticmethod
    def backward(  # type: ignore[override]
        ctx: Any, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        _i_sign = -1 * ctx.isign
        _mode_ordering = ctx.mode_ordering
        finufftkwargs = ctx.finufftkwargs

        points, values = ctx.saved_tensors
        points = torch.atleast_2d(points)
        device = points.device
        ndim = points.shape[0]

        grads_points = None
        grad_values = None

        nufft_func = get_nufft_func(ndim, 2, device)

        if any(ctx.needs_input_grad):
            if _mode_ordering:
                grad_output = batch_fftshift(grad_output, ndim)

            # group together batched dimensions, if any
            shape = grad_output.shape[-ndim:]
            batch_dims = grad_output.shape[:-ndim]
            batched_grad_output = grad_output.reshape(-1, 1, *shape)
            nbatch = batched_grad_output.shape[0]

            if ctx.needs_input_grad[0]:
                # wrt points
                coord_ramps = coordinate_ramps(shape, device)

                # nbatch x ndims x ...
                batched_values = values.reshape(nbatch, 1, values.shape[-1])

                ramped_grad_output = (
                    coord_ramps * batched_grad_output * 1j * _i_sign
                ).reshape(-1, *shape)
                backprop_ramp = (
                    nufft_func(
                        *points, ramped_grad_output, isign=_i_sign, **finufftkwargs
                    )
                    .conj()
                    .reshape(nbatch, ndim, -1)
                )
                grads_points = (backprop_ramp * batched_values).real.sum(dim=0)

            if ctx.needs_input_grad[1]:
                grad_values = nufft_func(
                    *points,
                    batched_grad_output.squeeze(),
                    isign=_i_sign,
                    **finufftkwargs,
                ).reshape(*batch_dims, -1)

        return (
            grads_points,
            grad_values,
            None,
            None,
            None,
            None,
        )


class FinufftType2(torch.autograd.Function):
    """
    FINUFFT problem type 2
    """

    ISIGN_DEFAULT = -1  # note: FINUFFT default is -1
    MODEORD_DEFAULT = 1  # note: FINUFFT default is 0

    @staticmethod
    def setup_context(
        ctx: Any,
        inputs: Tuple[
            torch.Tensor, torch.Tensor, Optional[Dict[str, Union[int, float]]]
        ],
        output: Any,
    ) -> None:
        points, targets, finufftkwargs = inputs
        if finufftkwargs is None:
            finufftkwargs = {}
        else:  # copy to avoid mutating caller's dictionary
            finufftkwargs = finufftkwargs.copy()
        ctx.save_for_backward(points, targets)
        ctx.isign = finufftkwargs.pop("isign", FinufftType2.ISIGN_DEFAULT)
        ctx.mode_ordering = finufftkwargs.pop("modeord", FinufftType2.MODEORD_DEFAULT)
        ctx.finufftkwargs = finufftkwargs

    @staticmethod
    def forward(  # type: ignore
        points: torch.Tensor,
        targets: torch.Tensor,
        finufftkwargs: Optional[Dict[str, Union[int, float]]] = None,
    ) -> torch.Tensor:
        checks.check_devices(targets, points)
        checks.check_dtypes(targets, points, "Targets")
        checks.check_sizes_t2(targets, points)

        if finufftkwargs is None:
            finufftkwargs = dict()
        else:  # copy to avoid mutating caller's dictionary
            finufftkwargs = finufftkwargs.copy()
        finufftkwargs.setdefault("isign", FinufftType2.ISIGN_DEFAULT)
        modeord = finufftkwargs.pop("modeord", FinufftType2.MODEORD_DEFAULT)

        points = torch.atleast_2d(points)
        ndim = points.shape[0]
        npoints = points.shape[1]
        if modeord:
            targets = batch_fftshift(targets, ndim)

        nufft_func = get_nufft_func(ndim, 2, points.device)

        batch_dims = targets.shape[:-ndim]
        shape = targets.shape[-ndim:]
        finufft_out = nufft_func(
            *points,
            targets.reshape(-1, *shape),
            **finufftkwargs,
        )
        finufft_out = finufft_out.reshape(*batch_dims, npoints)

        return finufft_out

    @staticmethod
    def vmap(  # type: ignore[override]
        info: Any,
        in_dims: Tuple[Optional[int], ...],
        points: torch.Tensor,
        targets: torch.Tensor,
        finufftkwargs: Optional[Dict[str, Union[int, float]]] = None,
    ) -> Tuple[torch.Tensor, int]:
        batch_points, batch_targets, *_ = in_dims

        if batch_targets is not None:
            targets = targets.movedim(batch_targets, 0)

        if batch_points is not None:  # need a for-loop here
            # potential opportunity for CUDA streams
            points = points.movedim(batch_points, 0)
            if batch_targets is not None:
                output = torch.stack(
                    [
                        FinufftType2.apply(
                            points[i],
                            targets[i],  # inner product
                            finufftkwargs,
                        )
                        for i in range(info.batch_size)
                    ],
                    dim=0,
                )
            else:
                output = torch.stack(
                    [
                        FinufftType2.apply(
                            points[i],
                            targets,
                            finufftkwargs,
                        )
                        for i in range(info.batch_size)
                    ],
                    dim=0,
                )
        else:
            output = FinufftType2.apply(points, targets, finufftkwargs)

        return output, 0

    @staticmethod
    def backward(  # type: ignore[override]
        ctx: Any, grad_output: torch.Tensor
    ) -> Tuple[
        Union[torch.Tensor, None],
        Union[torch.Tensor, None],
        None,
        None,
        None,
    ]:
        _i_sign = ctx.isign
        _mode_ordering = ctx.mode_ordering
        finufftkwargs = ctx.finufftkwargs

        points, targets = ctx.saved_tensors
        points = torch.atleast_2d(points)
        device = points.device
        ndim = points.shape[0]

        grad_points = None
        grad_targets = None

        if any(ctx.needs_input_grad):
            if _mode_ordering:
                # TODO this was also computed in forward
                targets = batch_fftshift(targets, ndim)

            batch_dims = targets.shape[:-ndim]
            shape = targets.shape[-ndim:]
            batched_targets = targets.reshape(-1, 1, *shape)
            nbatch = batched_targets.shape[0]
            batched_outputs = grad_output.reshape(nbatch, 1, grad_output.shape[-1])

            if ctx.needs_input_grad[0]:
                # wrt. points
                nufft_func = get_nufft_func(ndim, 2, points.device)
                coord_ramps = coordinate_ramps(shape, device)
                ramped_targets = (coord_ramps * batched_targets * 1j * _i_sign).reshape(
                    -1, *shape
                )
                backprop_ramp = (
                    nufft_func(*points, ramped_targets, isign=_i_sign, **finufftkwargs)
                    .conj()  # Why can't this `conj` be replaced with a flipped isign?
                    .reshape(nbatch, ndim, -1)
                )
                grad_points = (backprop_ramp * batched_outputs).real.sum(dim=0)

            if ctx.needs_input_grad[1]:
                # wrt. targets
                nufft_func = get_nufft_func(ndim, 1, points.device)

                grad_targets = nufft_func(
                    *points,
                    batched_outputs.squeeze(),
                    shape,
                    isign=-_i_sign,
                    **finufftkwargs,
                ).reshape(*batch_dims, *shape)

                if _mode_ordering:
                    grad_targets = batch_ifftshift(grad_targets, ndim)

        return (
            grad_points,
            grad_targets,
            None,
            None,
            None,
        )
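
# A minimal sanity-check sketch (illustrative, not part of the module's API):
# the custom backward passes above can be verified against finite differences
# with ``torch.autograd.gradcheck``. The point count, grid size, ``eps``, and
# tolerance below are arbitrary assumptions; double precision is used so the
# numerical comparison is meaningful.
#
#     points = 2 * torch.pi * torch.rand(2, 25, dtype=torch.float64) - torch.pi
#     points.requires_grad_(True)
#     values = torch.randn(25, dtype=torch.complex128, requires_grad=True)
#     torch.autograd.gradcheck(
#         lambda p, v: FinufftType1.apply(p, v, (8, 8), {"eps": 1e-12}),
#         (points, values),
#         atol=1e-5,
#     )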
def finufft_type1(
    points: torch.Tensor,
    values: torch.Tensor,
    output_shape: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]],
    **finufftkwargs: Union[int, float],
) -> torch.Tensor:
    """
    Evaluates the Type 1 (nonuniform-to-uniform) NUFFT on the inputs.

    This is a wrapper around :func:`finufft.nufft1d1`, :func:`finufft.nufft2d1`,
    and :func:`finufft.nufft3d1` on CPU, and :func:`cufinufft.nufft1d1`,
    :func:`cufinufft.nufft2d1`, and :func:`cufinufft.nufft3d1` on GPU.

    Parameters
    ----------
    points : torch.Tensor
        DxN tensor of locations of the non-uniform points. Points should lie
        in the range ``[-pi, pi]``; values outside will be folded.
    values : torch.Tensor
        Complex-valued tensor of values at the non-uniform points. All
        dimensions except the final dimension are treated as batch dimensions.
        The final dimension must have size ``N``.
    output_shape : int | tuple(int, ...)
        Requested output shape of Fourier modes. Must be a tuple of length D
        or an integer (1D only).
    **finufftkwargs : int | float
        Additional keyword arguments are forwarded to the underlying FINUFFT
        functions. A few notable options are

        - ``eps``: precision requested (default: ``1e-6``)
        - ``modeord``: 0 for FINUFFT default, 1 for PyTorch default
          (default: ``1``)
        - ``isign``: sign of the exponent in the Fourier transform
          (default: ``-1``)

    Returns
    -------
    torch.Tensor
        Tensor with shape ``*[batch], *output_shape`` containing the Fourier
        transform of the values.
    """
    res: torch.Tensor = FinufftType1.apply(points, values, output_shape, finufftkwargs)
    return res
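
# A minimal usage sketch for ``finufft_type1`` (the sizes and ``eps`` value
# are illustrative assumptions): spread complex strengths at 2D non-uniform
# points onto a 64x64 grid of Fourier modes, then backpropagate through the
# transform.
#
#     points = 2 * torch.pi * torch.rand(2, 1000, dtype=torch.float64) - torch.pi
#     values = torch.randn(1000, dtype=torch.complex128, requires_grad=True)
#     modes = finufft_type1(points, values, (64, 64), eps=1e-8)
#     assert modes.shape == (64, 64)
#     modes.abs().pow(2).sum().backward()  # gradients flow back into ``values``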
def finufft_type2(
    points: torch.Tensor,
    targets: torch.Tensor,
    **finufftkwargs: Union[int, float],
) -> torch.Tensor:
    """
    Evaluates the Type 2 (uniform-to-nonuniform) NUFFT on the inputs.

    This is a wrapper around :func:`finufft.nufft1d2`, :func:`finufft.nufft2d2`,
    and :func:`finufft.nufft3d2` on CPU, and :func:`cufinufft.nufft1d2`,
    :func:`cufinufft.nufft2d2`, and :func:`cufinufft.nufft3d2` on GPU.

    Parameters
    ----------
    points : torch.Tensor
        DxN tensor of locations of the non-uniform points. Points should lie
        in the range ``[-pi, pi]``; values outside will be folded.
    targets : torch.Tensor
        Complex-valued tensor of Fourier modes to evaluate at the points.
        The final D dimensions must contain the Fourier modes, and any
        preceding dimensions are treated as batch dimensions.
    **finufftkwargs : int | float
        Additional keyword arguments are forwarded to the underlying FINUFFT
        functions. A few notable options are

        - ``eps``: precision requested (default: ``1e-6``)
        - ``modeord``: 0 for FINUFFT default, 1 for PyTorch default
          (default: ``1``)
        - ``isign``: sign of the exponent in the Fourier transform
          (default: ``-1``)

    Returns
    -------
    torch.Tensor
        A ``[batch]xN`` tensor of values at the non-uniform points.
    """
    res: torch.Tensor = FinufftType2.apply(points, targets, finufftkwargs)
    return res
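
# A minimal usage sketch for ``finufft_type2`` (batch size, grid size, and
# point count are illustrative assumptions): evaluate a batch of three 32x32
# Fourier-mode grids at a shared set of 2D non-uniform points. The leading
# dimension of ``targets`` is treated as a batch dimension, so the result
# has shape ``(3, 500)``.
#
#     points = 2 * torch.pi * torch.rand(2, 500, dtype=torch.float64) - torch.pi
#     targets = torch.randn(3, 32, 32, dtype=torch.complex128)
#     samples = finufft_type2(points, targets)
#     assert samples.shape == (3, 500)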