
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/convolution_2d.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_convolution_2d.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_convolution_2d.py:


Convolution in 2D
=================

.. GENERATED FROM PYTHON SOURCE LINES 7-11

Import packages
---------------

First, we import the packages we need for this example.

.. GENERATED FROM PYTHON SOURCE LINES 11-18

.. code-block:: Python


    import matplotlib.pyplot as plt
    import numpy as np
    import torch

    import pytorch_finufft


.. GENERATED FROM PYTHON SOURCE LINES 19-20

Let's create a Gaussian convolutional filter as a function of x and y.

.. GENERATED FROM PYTHON SOURCE LINES 20-26

.. code-block:: Python



    def gaussian_function(x, y, sigma=1):
        return np.exp(-(x**2 + y**2) / (2 * sigma**2))



.. GENERATED FROM PYTHON SOURCE LINES 27-29

Let's visualize this filter kernel. We will convolve it with points living on
the :math:`[0, 2\pi] \times [0, 2\pi]` torus, so let's dimension it accordingly.

.. GENERATED FROM PYTHON SOURCE LINES 29-40

.. code-block:: Python


    shape = (128, 128)
    sigma = 0.5
    x = np.linspace(-np.pi, np.pi, shape[0], endpoint=False)
    y = np.linspace(-np.pi, np.pi, shape[1], endpoint=False)

    gaussian_kernel = gaussian_function(x[:, np.newaxis], y, sigma=sigma)

    fig, ax = plt.subplots()
    _ = ax.imshow(gaussian_kernel)


.. image-sg:: /examples/images/sphx_glr_convolution_2d_001.png
   :alt: convolution 2d
   :srcset: /examples/images/sphx_glr_convolution_2d_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 41-43

So that the kernel does not shift the signal, we need to place its mass at 0.
To do this, we apply ``ifftshift`` to the kernel.

.. GENERATED FROM PYTHON SOURCE LINES 43-50

.. code-block:: Python


    shifted_gaussian_kernel = np.fft.ifftshift(gaussian_kernel)

    fig, ax = plt.subplots()
    _ = ax.imshow(shifted_gaussian_kernel)


.. image-sg:: /examples/images/sphx_glr_convolution_2d_002.png
   :alt: convolution 2d
   :srcset: /examples/images/sphx_glr_convolution_2d_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 51-52

Now let's create a point cloud on the torus that we can convolve with our filter.

.. GENERATED FROM PYTHON SOURCE LINES 52-63

.. code-block:: Python


    N = 20
    points = np.random.rand(2, N) * 2 * np.pi

    fig, ax = plt.subplots()
    ax.set_xlim(0, 2 * np.pi)
    ax.set_ylim(0, 2 * np.pi)
    ax.set_aspect("equal")
    _ = ax.scatter(points[0], points[1], s=1)


.. image-sg:: /examples/images/sphx_glr_convolution_2d_003.png
   :alt: convolution 2d
   :srcset: /examples/images/sphx_glr_convolution_2d_003.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 64-68

Now we can convolve the point cloud with the filter kernel.
To do this, we Fourier-transform both the point cloud and the filter kernel,
multiply them together, and then inverse Fourier-transform the result.
First we need to convert all data to torch tensors.

.. GENERATED FROM PYTHON SOURCE LINES 68-89

.. code-block:: Python


    fourier_shifted_gaussian_kernel = torch.fft.fft2(
        torch.from_numpy(shifted_gaussian_kernel)
    )
    fourier_points = pytorch_finufft.functional.finufft_type1(
        torch.from_numpy(points), torch.ones(points.shape[1], dtype=torch.complex128), shape
    )

    fig, axs = plt.subplots(1, 3)
    axs[0].imshow(fourier_shifted_gaussian_kernel.real)
    axs[1].imshow(fourier_points.real, vmin=-10, vmax=10)
    _ = axs[2].imshow(
        (
            fourier_points
            * fourier_shifted_gaussian_kernel
            / fourier_shifted_gaussian_kernel[0, 0]
        ).real,
        vmin=-10,
        vmax=10,
    )


.. image-sg:: /examples/images/sphx_glr_convolution_2d_004.png
   :alt: convolution 2d
   :srcset: /examples/images/sphx_glr_convolution_2d_004.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 90-93

We now have two possibilities: invert the Fourier transform on a grid, or on a point cloud.
We'll first invert the Fourier transform on a grid so that we can visualize the effect of the convolution.

.. GENERATED FROM PYTHON SOURCE LINES 93-102

.. code-block:: Python


    convolved_points = torch.fft.ifft2(fourier_points * fourier_shifted_gaussian_kernel)

    fig, ax = plt.subplots()
    ax.imshow(convolved_points.real)
    _ = ax.scatter(
        points[1] / 2 / np.pi * shape[0], points[0] / 2 / np.pi * shape[1], s=2, c="r"
    )


.. image-sg:: /examples/images/sphx_glr_convolution_2d_005.png
   :alt: convolution 2d
   :srcset: /examples/images/sphx_glr_convolution_2d_005.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 103-106

We see that the convolution has smeared out the point cloud.
After a small coordinate change, we can also plot the original points on the same plot as the convolved points.

.. GENERATED FROM PYTHON SOURCE LINES 109-112

Next, we invert the Fourier transform on the same points as our original point cloud.
We will then compare this to direct evaluation of the kernel on all pairwise difference vectors between the points.

.. GENERATED FROM PYTHON SOURCE LINES 112-128

.. code-block:: Python


    convolved_at_points = pytorch_finufft.functional.finufft_type2(
        torch.from_numpy(points),
        fourier_points * fourier_shifted_gaussian_kernel,
        isign=1,
    ).real / np.prod(shape)

    fig, ax = plt.subplots()
    ax.imshow(convolved_points.real)
    _ = ax.scatter(
        points[1] / 2 / np.pi * shape[0],
        points[0] / 2 / np.pi * shape[1],
        s=10 * convolved_at_points,
        c="r",
    )


.. image-sg:: /examples/images/sphx_glr_convolution_2d_006.png
   :alt: convolution 2d
   :srcset: /examples/images/sphx_glr_convolution_2d_006.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 129-132

To compute the convolution directly, we need to evaluate the kernel on all pairwise difference vectors between the points.
Note that some points will be off the diagonal.
These are due to the periodic boundary conditions of the convolution.

.. GENERATED FROM PYTHON SOURCE LINES 132-150

.. code-block:: Python


    pairwise_diffs = points[:, np.newaxis] - points[:, :, np.newaxis]
    kernel_diff_evals = gaussian_function(*pairwise_diffs, sigma=sigma)
    convolved_by_hand = kernel_diff_evals.sum(1)

    fig, ax = plt.subplots()
    ax.plot(convolved_at_points.numpy(), convolved_by_hand, ".")
    ax.plot([1, 3], [1, 3])

    relative_difference = torch.norm(
        convolved_at_points - convolved_by_hand
    ) / np.linalg.norm(convolved_by_hand)
    print(
        "Relative difference between fourier convolution and direct convolution "
        f"{relative_difference}"
    )


.. image-sg:: /examples/images/sphx_glr_convolution_2d_007.png
   :alt: convolution 2d
   :srcset: /examples/images/sphx_glr_convolution_2d_007.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/runner/work/pytorch-finufft/pytorch-finufft/examples/convolution_2d.py:142: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)
      convolved_at_points - convolved_by_hand
    Relative difference between fourier convolution and direct convolution 0.02535164429649562


.. GENERATED FROM PYTHON SOURCE LINES 151-154

Now let's see if we can learn the convolution kernel from the input and output point clouds.
To this end, let's first make a PyTorch module that can compute a kernel convolution on a point cloud.

.. GENERATED FROM PYTHON SOURCE LINES 154-185

.. code-block:: Python



    class FourierPointConvolution(torch.nn.Module):
        def __init__(self, fourier_kernel_shape):
            super().__init__()
            self.fourier_kernel_shape = fourier_kernel_shape

            self.build()

        def build(self):
            self.register_parameter(
                "fourier_kernel",
                torch.nn.Parameter(
                    torch.randn(self.fourier_kernel_shape, dtype=torch.complex128)
                ),
            )
            # ^ think about whether we need to scale this init in some better way

        def forward(self, points, values):
            fourier_transformed_input = pytorch_finufft.functional.finufft_type1(
                points, values, self.fourier_kernel_shape
            )
            fourier_convolved = fourier_transformed_input * self.fourier_kernel
            convolved = pytorch_finufft.functional.finufft_type2(
                points,
                fourier_convolved,
                isign=1,
            ).real / np.prod(self.fourier_kernel_shape)
            return convolved



.. GENERATED FROM PYTHON SOURCE LINES 186-188

Now we can use this module in a PyTorch training loop to learn the kernel from the input and output point clouds.
We will use the mean squared error as the loss function.

.. GENERATED FROM PYTHON SOURCE LINES 188-235

.. code-block:: Python


    fourier_point_convolution = FourierPointConvolution(shape)
    optimizer = torch.optim.AdamW(
        fourier_point_convolution.parameters(), lr=0.005, weight_decay=0.001
    )

    ones = torch.ones(points.shape[1], dtype=torch.complex128)

    losses = []
    for i in range(10000):
        # Make new set of points and compute forward model
        points = np.random.rand(2, N) * 2 * np.pi
        torch_points = torch.from_numpy(points)
        fourier_points = pytorch_finufft.functional.finufft_type1(
            torch.from_numpy(points), ones, shape
        )
        convolved_at_points = pytorch_finufft.functional.finufft_type2(
            torch.from_numpy(points),
            fourier_points * fourier_shifted_gaussian_kernel,
            isign=1,
        ).real / np.prod(shape)

        # Learning step
        optimizer.zero_grad()
        convolved = fourier_point_convolution(torch_points, ones)
        loss = torch.nn.functional.mse_loss(convolved, convolved_at_points)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print(f"Iteration {i:05d}, Loss: {loss.item():1.4f}")

    fig, ax = plt.subplots()
    ax.plot(losses)
    ax.set_ylabel("Loss")
    ax.set_xlabel("Iteration")
    ax.set_yscale("log")

    fig, ax = plt.subplots()
    im = ax.imshow(
        torch.real(torch.fft.fftshift(fourier_point_convolution.fourier_kernel.data))[
            48:80, 48:80
        ]
    )
    _ = fig.colorbar(im, ax=ax)


.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /examples/images/sphx_glr_convolution_2d_008.png
         :alt: convolution 2d
         :srcset: /examples/images/sphx_glr_convolution_2d_008.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /examples/images/sphx_glr_convolution_2d_009.png
         :alt: convolution 2d
         :srcset: /examples/images/sphx_glr_convolution_2d_009.png
         :class: sphx-glr-multi-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Iteration 00000, Loss: 3.0631
    Iteration 00100, Loss: 2.4286
    Iteration 00200, Loss: 1.1522
    Iteration 00300, Loss: 1.1646
    Iteration 00400, Loss: 2.5064
    Iteration 00500, Loss: 1.5273
    Iteration 00600, Loss: 0.6250
    Iteration 00700, Loss: 0.3812
    Iteration 00800, Loss: 0.7302
    Iteration 00900, Loss: 0.2382
    Iteration 01000, Loss: 0.2265
    Iteration 01100, Loss: 0.2691
    Iteration 01200, Loss: 0.3496
    Iteration 01300, Loss: 0.2065
    Iteration 01400, Loss: 0.2455
    Iteration 01500, Loss: 0.2089
    Iteration 01600, Loss: 0.3406
    Iteration 01700, Loss: 0.2549
    Iteration 01800, Loss: 0.2508
    Iteration 01900, Loss: 0.3798
    Iteration 02000, Loss: 0.2318
    Iteration 02100, Loss: 0.1519
    Iteration 02200, Loss: 0.1855
    Iteration 02300, Loss: 0.5736
    Iteration 02400, Loss: 0.5329
    Iteration 02500, Loss: 0.3214
    Iteration 02600, Loss: 0.2928
    Iteration 02700, Loss: 0.3313
    Iteration 02800, Loss: 0.2698
    Iteration 02900, Loss: 0.1951
    Iteration 03000, Loss: 0.3175
    Iteration 03100, Loss: 0.1815
    Iteration 03200, Loss: 0.4317
    Iteration 03300, Loss: 0.2332
    Iteration 03400, Loss: 0.1141
    Iteration 03500, Loss: 0.2738
    Iteration 03600, Loss: 0.8844
    Iteration 03700, Loss: 0.3819
    Iteration 03800, Loss: 0.3370
    Iteration 03900, Loss: 0.1624
    Iteration 04000, Loss: 0.3616
    Iteration 04100, Loss: 0.1929
    Iteration 04200, Loss: 0.1788
    Iteration 04300, Loss: 0.3214
    Iteration 04400, Loss: 0.1094
    Iteration 04500, Loss: 0.2928
    Iteration 04600, Loss: 0.1789
    Iteration 04700, Loss: 0.2060
    Iteration 04800, Loss: 0.7645
    Iteration 04900, Loss: 0.2288
    Iteration 05000, Loss: 0.2417
    Iteration 05100, Loss: 0.3978
    Iteration 05200, Loss: 0.0938
    Iteration 05300, Loss: 0.1879
    Iteration 05400, Loss: 0.4689
    Iteration 05500, Loss: 0.1722
    Iteration 05600, Loss: 0.2085
    Iteration 05700, Loss: 0.1926
    Iteration 05800, Loss: 0.3311
    Iteration 05900, Loss: 0.1670
    Iteration 06000, Loss: 0.2194
    Iteration 06100, Loss: 0.2344
    Iteration 06200, Loss: 0.2325
    Iteration 06300, Loss: 0.4041
    Iteration 06400, Loss: 0.2492
    Iteration 06500, Loss: 0.2746
    Iteration 06600, Loss: 0.1514
    Iteration 06700, Loss: 0.3190
    Iteration 06800, Loss: 0.5207
    Iteration 06900, Loss: 0.1167
    Iteration 07000, Loss: 0.3730
    Iteration 07100, Loss: 0.4862
    Iteration 07200, Loss: 0.2018
    Iteration 07300, Loss: 0.2163
    Iteration 07400, Loss: 0.2765
    Iteration 07500, Loss: 0.2773
    Iteration 07600, Loss: 0.4174
    Iteration 07700, Loss: 0.2125
    Iteration 07800, Loss: 0.2311
    Iteration 07900, Loss: 0.5790
    Iteration 08000, Loss: 0.5696
    Iteration 08100, Loss: 0.2225
    Iteration 08200, Loss: 1.1715
    Iteration 08300, Loss: 0.1278
    Iteration 08400, Loss: 0.1849
    Iteration 08500, Loss: 0.2257
    Iteration 08600, Loss: 0.2808
    Iteration 08700, Loss: 0.5665
    Iteration 08800, Loss: 0.3757
    Iteration 08900, Loss: 0.6134
    Iteration 09000, Loss: 0.2263
    Iteration 09100, Loss: 0.5041
    Iteration 09200, Loss: 0.2298
    Iteration 09300, Loss: 0.3916
    Iteration 09400, Loss: 0.2168
    Iteration 09500, Loss: 0.2280
    Iteration 09600, Loss: 0.2622
    Iteration 09700, Loss: 0.9283
    Iteration 09800, Loss: 1.1408
    Iteration 09900, Loss: 0.7503


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (1 minutes 8.720 seconds)


.. _sphx_glr_download_examples_convolution_2d.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: convolution_2d.ipynb <convolution_2d.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: convolution_2d.py <convolution_2d.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: convolution_2d.zip <convolution_2d.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
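
As a supplementary sketch for the direct-versus-Fourier comparison above: the Fourier convolution is periodic on the torus, while the direct evaluation uses raw difference vectors, which is one plausible source of the off-diagonal points. The snippet below is NumPy-only and not part of the original example; the helper ``wrap_to_pi`` and the fixed random seed are assumptions introduced here for illustration. It evaluates the same Gaussian on differences wrapped to the nearest periodic image, which should match the periodic picture more closely.

```python
import numpy as np


def gaussian_function(x, y, sigma=1):
    # Same kernel definition as in the example above.
    return np.exp(-(x**2 + y**2) / (2 * sigma**2))


def wrap_to_pi(d):
    # Hypothetical helper: map differences onto [-pi, pi), i.e. pick the
    # nearest periodic image on a torus with period 2*pi.
    return (d + np.pi) % (2 * np.pi) - np.pi


rng = np.random.default_rng(0)  # fixed seed, only for reproducibility
N, sigma = 20, 0.5
points = rng.random((2, N)) * 2 * np.pi

# Same pairwise-difference layout as in the example: shape (2, N, N).
pairwise_diffs = points[:, np.newaxis] - points[:, :, np.newaxis]

# Direct sum without periodicity (as in the example) ...
direct_open = gaussian_function(*pairwise_diffs, sigma=sigma).sum(1)
# ... and with differences wrapped onto the torus.
direct_periodic = gaussian_function(*wrap_to_pi(pairwise_diffs), sigma=sigma).sum(1)

# Points near the boundary gain contributions from neighbours across the seam.
print(np.max(direct_periodic - direct_open))
```

Wrapping can only shorten each difference component, so ``direct_periodic`` is never smaller than ``direct_open``. Points whose values differ between the two are those with neighbours across the periodic seam.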