Convolution in 2D

Import packages

First, we import the packages we need for this example.

import matplotlib.pyplot as plt
import numpy as np
import torch

import pytorch_finufft

Let’s create a Gaussian convolutional filter as a function of x and y.

def gaussian_function(x, y, sigma=1):
    return np.exp(-(x**2 + y**2) / (2 * sigma**2))

Let’s visualize this filter kernel. We will convolve it with points living on the \([0, 2\pi) \times [0, 2\pi)\) torus, so we sample it on a grid that spans one full period, \([-\pi, \pi)\), in each direction.

shape = (128, 128)
sigma = 0.5
x = np.linspace(-np.pi, np.pi, shape[0], endpoint=False)
y = np.linspace(-np.pi, np.pi, shape[1], endpoint=False)

gaussian_kernel = gaussian_function(x[:, np.newaxis], y, sigma=sigma)

fig, ax = plt.subplots()
_ = ax.imshow(gaussian_kernel)

So that the kernel does not shift the signal, its mass must sit at index 0 of the array rather than at the center. To do this, we apply ifftshift to the kernel.
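The shift can be sketched with NumPy alone. This is a minimal illustration, not necessarily the code used to produce the figure; the name `shifted_gaussian_kernel` is an assumption inferred from the `fourier_shifted_gaussian_kernel` variable used further down.

```python
import numpy as np

def gaussian_function(x, y, sigma=1):
    return np.exp(-(x**2 + y**2) / (2 * sigma**2))

shape = (128, 128)
sigma = 0.5
x = np.linspace(-np.pi, np.pi, shape[0], endpoint=False)
y = np.linspace(-np.pi, np.pi, shape[1], endpoint=False)
gaussian_kernel = gaussian_function(x[:, np.newaxis], y, sigma=sigma)

# Before shifting, the peak sits at the centre of the array.
peak_before = np.unravel_index(np.argmax(gaussian_kernel), shape)

# ifftshift rolls each axis so that the centre sample moves to index 0.
# The name shifted_gaussian_kernel is assumed, not taken from the source.
shifted_gaussian_kernel = np.fft.ifftshift(gaussian_kernel)
peak_after = np.unravel_index(np.argmax(shifted_gaussian_kernel), shape)

print(peak_before, peak_after)
```

Passing `shifted_gaussian_kernel` to `ax.imshow` should then show the Gaussian split across the four corners of the image.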


Now let’s create a point cloud on the torus that we can convolve with our filter.

N = 20
points = np.random.rand(2, N) * 2 * np.pi

fig, ax = plt.subplots()
ax.set_xlim(0, 2 * np.pi)
ax.set_ylim(0, 2 * np.pi)
ax.set_aspect("equal")
_ = ax.scatter(points[0], points[1], s=1)

Now we can convolve the point cloud with the filter kernel. To do this, we Fourier-transform both the point cloud and the filter kernel, multiply the two transforms, and then inverse Fourier-transform the product. First we need to convert all data to torch tensors.
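To make the three steps concrete, here is a NumPy-only sketch on a small grid. It replaces the non-uniform FFT by a direct sum, so it illustrates the mathematics rather than the tutorial’s implementation. The sign convention `exp(-i k.x)`, the FFT-style ordering of frequencies, and the small grid size are assumptions of this sketch; the names `fourier_points` and `fourier_shifted_gaussian_kernel` mirror those used in the code further down, which presumably obtains them with `pytorch_finufft.functional.finufft_type1` (as in the training loop) and a 2D FFT of the shifted kernel.

```python
import numpy as np

def gaussian_function(x, y, sigma=1):
    return np.exp(-(x**2 + y**2) / (2 * sigma**2))

rng = np.random.default_rng(0)
shape = (16, 16)  # small grid so the direct sums stay cheap
N = 5
points = rng.random((2, N)) * 2 * np.pi

# Kernel sampled on [-pi, pi)^2, then shifted so its peak sits at index [0, 0].
x = np.linspace(-np.pi, np.pi, shape[0], endpoint=False)
y = np.linspace(-np.pi, np.pi, shape[1], endpoint=False)
shifted_gaussian_kernel = np.fft.ifftshift(
    gaussian_function(x[:, np.newaxis], y, sigma=0.5)
)
fourier_shifted_gaussian_kernel = np.fft.fft2(shifted_gaussian_kernel)

# Integer frequencies in FFT ordering: 0, 1, ..., -2, -1.
k0 = np.fft.fftfreq(shape[0], d=1 / shape[0])
k1 = np.fft.fftfreq(shape[1], d=1 / shape[1])

# Direct sum over unit point masses, assuming the sign convention exp(-i k.x):
# fourier_points[a, b] = sum_j exp(-i * (k0[a] * x_j + k1[b] * y_j))
phase0 = np.exp(-1j * np.outer(k0, points[0]))  # shape (shape[0], N)
phase1 = np.exp(-1j * np.outer(k1, points[1]))  # shape (shape[1], N)
fourier_points = phase0 @ phase1.T

# Convolution theorem: multiply in Fourier space, then transform back.
convolved_points = np.fft.ifft2(fourier_points * fourier_shifted_gaussian_kernel)
```

The direct sum costs on the order of `N * shape[0] * shape[1]` operations, which is why a non-uniform FFT is preferable for realistic sizes.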


We now have two possibilities: invert the Fourier transform on a grid, or on a point cloud. We’ll first invert it on a grid so that we can visualize the effect of the convolution.

convolved_points = torch.fft.ifft2(fourier_points * fourier_shifted_gaussian_kernel)

fig, ax = plt.subplots()
ax.imshow(convolved_points.real)
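_ = ax.scatter(
    # Rescale torus coordinates to pixel coordinates: the horizontal plot axis
    # runs along array axis 1, the vertical plot axis along array axis 0.
    points[1] / 2 / np.pi * shape[1], points[0] / 2 / np.pi * shape[0], s=2, c="r"
)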

We see that the convolution has smeared out the point cloud. The red dots are the original points, rescaled from torus coordinates to pixel coordinates so that they can be drawn on top of the convolved image.

Next, we invert the Fourier transform on the same points as our original point cloud. We will then compare this to direct evaluation of the kernel on all pairwise difference vectors between the points.
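Inverting at arbitrary points amounts to evaluating the Fourier series there. The NumPy sketch below does this with a direct sum and checks that, at grid nodes, it agrees with the gridded inverse FFT. The helper `evaluate_at_points` is hypothetical, and the `exp(+i k.x)` sign, the FFT-style frequency ordering, and the division by the number of modes are assumptions chosen to match the `isign=1` and `/ np.prod(shape)` seen in the tutorial code.

```python
import numpy as np

def evaluate_at_points(coeffs, points):
    """Direct sum f(x_j) = 1/(M0*M1) * sum_k coeffs[k] * exp(+i k.x_j)."""
    m0, m1 = coeffs.shape
    k0 = np.fft.fftfreq(m0, d=1 / m0)  # integer frequencies, FFT ordering
    k1 = np.fft.fftfreq(m1, d=1 / m1)
    phase0 = np.exp(1j * np.outer(points[0], k0))  # shape (N, M0)
    phase1 = np.exp(1j * np.outer(points[1], k1))  # shape (N, M1)
    return np.einsum("ja,ab,jb->j", phase0, coeffs, phase1) / (m0 * m1)

rng = np.random.default_rng(1)
shape = (8, 8)
coeffs = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

# Pick a few grid nodes and express them as torus coordinates in [0, 2*pi).
grid_indices = np.array([[0, 3, 5], [0, 2, 7]])
grid_points = grid_indices * 2 * np.pi / np.array(shape)[:, np.newaxis]

at_points = evaluate_at_points(coeffs, grid_points)
on_grid = np.fft.ifft2(coeffs)[grid_indices[0], grid_indices[1]]
print(np.allclose(at_points, on_grid))
```

Away from grid nodes the same sum interpolates between grid values, which is what makes evaluation at the original point cloud possible.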

convolved_at_points = pytorch_finufft.functional.finufft_type2(
    torch.from_numpy(points),
    fourier_points * fourier_shifted_gaussian_kernel,
    isign=1,
).real / np.prod(shape)

fig, ax = plt.subplots()
ax.imshow(convolved_points.real)
_ = ax.scatter(
    points[1] / 2 / np.pi * shape[1],
    points[0] / 2 / np.pi * shape[0],
    s=10 * convolved_at_points,
    c="r",
)

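To compute the convolution directly, we need to evaluate the kernel on all pairwise difference vectors between the points. In the scatter plot that compares the two results, note the points that lie off the diagonal. These deviations are due to the periodic boundary conditions of the Fourier convolution, which the direct evaluation below does not take into account.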
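One way to probe this explanation is to wrap each difference vector to its shortest displacement on the torus before evaluating the kernel. The sketch below is an illustrative variant, not part of the original example; `wrap_to_pi` is a hypothetical helper, and only the nearest periodic image is kept, which should be adequate when sigma is small compared with the period.

```python
import numpy as np

def gaussian_function(x, y, sigma=1):
    return np.exp(-(x**2 + y**2) / (2 * sigma**2))

def wrap_to_pi(d):
    """Map differences onto [-pi, pi), the shortest displacement on the torus."""
    return (d + np.pi) % (2 * np.pi) - np.pi

sigma = 0.5
# Two points that are close to each other across the periodic boundary.
points = np.array([[0.1, 2 * np.pi - 0.1],
                   [1.0, 1.0]])

pairwise_diffs = points[:, np.newaxis] - points[:, :, np.newaxis]

# Plain differences treat the two points as far apart ...
plain = gaussian_function(*pairwise_diffs, sigma=sigma).sum(1)
# ... whereas wrapped differences see them as neighbours at distance 0.2.
wrapped = gaussian_function(*wrap_to_pi(pairwise_diffs), sigma=sigma).sum(1)

print(plain, wrapped)
```

For points near the boundary the wrapped sum is noticeably larger than the plain one, which is the kind of discrepancy that would move a point off the diagonal.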

pairwise_diffs = points[:, np.newaxis] - points[:, :, np.newaxis]
kernel_diff_evals = gaussian_function(*pairwise_diffs, sigma=sigma)
convolved_by_hand = kernel_diff_evals.sum(1)

fig, ax = plt.subplots()
ax.plot(convolved_at_points.numpy(), convolved_by_hand, ".")
ax.plot([1, 3], [1, 3])

# Convert the torch result to NumPy first so that both operands have the same type.
relative_difference = np.linalg.norm(
    convolved_at_points.numpy() - convolved_by_hand
) / np.linalg.norm(convolved_by_hand)
print(
    "Relative difference between fourier convolution and direct convolution "
    f"{relative_difference}"
)
Relative difference between fourier convolution and direct convolution 0.04718525691856909

Now let’s see if we can learn the convolution kernel from the input and output point clouds. To this end, let’s first write a PyTorch module that computes a kernel convolution on a point cloud.

class FourierPointConvolution(torch.nn.Module):
    def __init__(self, fourier_kernel_shape):
        super().__init__()
        self.fourier_kernel_shape = fourier_kernel_shape

        self.build()

    def build(self):
        self.register_parameter(
            "fourier_kernel",
            torch.nn.Parameter(
                torch.randn(self.fourier_kernel_shape, dtype=torch.complex128)
            ),
        )
        # TODO: consider whether this random initialization should be scaled differently

    def forward(self, points, values):
        fourier_transformed_input = pytorch_finufft.functional.finufft_type1(
            points, values, self.fourier_kernel_shape
        )
        fourier_convolved = fourier_transformed_input * self.fourier_kernel
        convolved = pytorch_finufft.functional.finufft_type2(
            points,
            fourier_convolved,
            isign=1,
        ).real / np.prod(self.fourier_kernel_shape)
        return convolved

Now we can use this module in a PyTorch training loop to learn the kernel from the input and output point clouds. We will use the mean squared error as the loss function.

fourier_point_convolution = FourierPointConvolution(shape)
optimizer = torch.optim.AdamW(
    fourier_point_convolution.parameters(), lr=0.005, weight_decay=0.001
)

ones = torch.ones(points.shape[1], dtype=torch.complex128)

losses = []
for i in range(10000):
    # Make new set of points and compute forward model
    points = np.random.rand(2, N) * 2 * np.pi
    torch_points = torch.from_numpy(points)
    fourier_points = pytorch_finufft.functional.finufft_type1(
        torch_points, ones, shape
    )
    convolved_at_points = pytorch_finufft.functional.finufft_type2(
        torch_points,
        fourier_points * fourier_shifted_gaussian_kernel,
        isign=1,
    ).real / np.prod(shape)

    # Learning step
    optimizer.zero_grad()
    convolved = fourier_point_convolution(torch_points, ones)
    loss = torch.nn.functional.mse_loss(convolved, convolved_at_points)
    losses.append(loss.item())
    loss.backward()
    optimizer.step()

    if i % 100 == 0:
        print(f"Iteration {i:05d}, Loss: {loss.item():1.4f}")


fig, ax = plt.subplots()
ax.plot(losses)
ax.set_ylabel("Loss")
ax.set_xlabel("Iteration")
ax.set_yscale("log")

fig, ax = plt.subplots()
im = ax.imshow(
    torch.real(torch.fft.fftshift(fourier_point_convolution.fourier_kernel.data))[
        48:80, 48:80
    ]
)
_ = fig.colorbar(im, ax=ax)
Iteration 00000, Loss: 2.5809
Iteration 00100, Loss: 1.4808
Iteration 00200, Loss: 1.1741
Iteration 00300, Loss: 0.9850
Iteration 00400, Loss: 0.9863
Iteration 00500, Loss: 2.5158
Iteration 00600, Loss: 0.3715
Iteration 00700, Loss: 0.3903
Iteration 00800, Loss: 0.2631
Iteration 00900, Loss: 0.2134
Iteration 01000, Loss: 0.1860
Iteration 01100, Loss: 0.3245
Iteration 01200, Loss: 0.2940
Iteration 01300, Loss: 0.2647
Iteration 01400, Loss: 0.2933
Iteration 01500, Loss: 0.1928
Iteration 01600, Loss: 0.3202
Iteration 01700, Loss: 0.2680
Iteration 01800, Loss: 0.1661
Iteration 01900, Loss: 1.1658
Iteration 02000, Loss: 0.2282
Iteration 02100, Loss: 0.2269
Iteration 02200, Loss: 0.3766
Iteration 02300, Loss: 0.4789
Iteration 02400, Loss: 0.1356
Iteration 02500, Loss: 0.7934
Iteration 02600, Loss: 0.1814
Iteration 02700, Loss: 0.2255
Iteration 02800, Loss: 0.3606
Iteration 02900, Loss: 0.3494
Iteration 03000, Loss: 0.2197
Iteration 03100, Loss: 0.1896
Iteration 03200, Loss: 0.4261
Iteration 03300, Loss: 0.2718
Iteration 03400, Loss: 0.2944
Iteration 03500, Loss: 0.2876
Iteration 03600, Loss: 0.2180
Iteration 03700, Loss: 0.2821
Iteration 03800, Loss: 0.3102
Iteration 03900, Loss: 0.1819
Iteration 04000, Loss: 0.4096
Iteration 04100, Loss: 0.2575
Iteration 04200, Loss: 0.8532
Iteration 04300, Loss: 0.2172
Iteration 04400, Loss: 0.2199
Iteration 04500, Loss: 0.1280
Iteration 04600, Loss: 0.6115
Iteration 04700, Loss: 0.2099
Iteration 04800, Loss: 0.4761
Iteration 04900, Loss: 0.1644
Iteration 05000, Loss: 0.1703
Iteration 05100, Loss: 1.2774
Iteration 05200, Loss: 0.9800
Iteration 05300, Loss: 0.1741
Iteration 05400, Loss: 0.5176
Iteration 05500, Loss: 0.1916
Iteration 05600, Loss: 0.2312
Iteration 05700, Loss: 0.1687
Iteration 05800, Loss: 0.2477
Iteration 05900, Loss: 0.3528
Iteration 06000, Loss: 0.1123
Iteration 06100, Loss: 0.1784
Iteration 06200, Loss: 0.2072
Iteration 06300, Loss: 0.1583
Iteration 06400, Loss: 0.2110
Iteration 06500, Loss: 1.6568
Iteration 06600, Loss: 0.1551
Iteration 06700, Loss: 0.4702
Iteration 06800, Loss: 0.3315
Iteration 06900, Loss: 0.2461
Iteration 07000, Loss: 0.2307
Iteration 07100, Loss: 0.2240
Iteration 07200, Loss: 0.1392
Iteration 07300, Loss: 0.1952
Iteration 07400, Loss: 0.1233
Iteration 07500, Loss: 0.3938
Iteration 07600, Loss: 0.1688
Iteration 07700, Loss: 0.3303
Iteration 07800, Loss: 0.4709
Iteration 07900, Loss: 0.4532
Iteration 08000, Loss: 0.1433
Iteration 08100, Loss: 0.1037
Iteration 08200, Loss: 0.5045
Iteration 08300, Loss: 0.2721
Iteration 08400, Loss: 0.4671
Iteration 08500, Loss: 0.1922
Iteration 08600, Loss: 0.2169
Iteration 08700, Loss: 0.4239
Iteration 08800, Loss: 0.1579
Iteration 08900, Loss: 0.1893
Iteration 09000, Loss: 0.1590
Iteration 09100, Loss: 0.1064
Iteration 09200, Loss: 0.2484
Iteration 09300, Loss: 0.1997
Iteration 09400, Loss: 0.1481
Iteration 09500, Loss: 0.3447
Iteration 09600, Loss: 0.2854
Iteration 09700, Loss: 0.3618
Iteration 09800, Loss: 0.2017
Iteration 09900, Loss: 0.2713

Total running time of the script: (1 minutes 10.252 seconds)
