Skip to content

Commit

Permalink
Add padding='same' for dilation=1, stride=1
Browse files Browse the repository at this point in the history
  • Loading branch information
papkov committed Mar 21, 2022
1 parent 40fcd51 commit f13c3eb
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 9 deletions.
32 changes: 25 additions & 7 deletions fft_conv_pytorch/fft_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch.nn.functional as f
from torch import Tensor, nn
from torch.fft import irfftn, rfftn
from math import ceil, floor


def complex_matmul(a: Tensor, b: Tensor, groups: int = 1) -> Tensor:
Expand Down Expand Up @@ -55,7 +56,7 @@ def fft_conv(
signal: Tensor,
kernel: Tensor,
bias: Tensor = None,
padding: Union[int, Iterable[int]] = 0,
padding: Union[int, Iterable[int], str] = 0,
padding_mode: str = "constant",
stride: Union[int, Iterable[int]] = 1,
dilation: Union[int, Iterable[int]] = 1,
Expand All @@ -69,19 +70,31 @@ def fft_conv(
signal: (Tensor) Input tensor to be convolved with the kernel.
kernel: (Tensor) Convolution kernel.
bias: (Tensor) Bias tensor to add to the output.
padding: (Union[int, Iterable[int]) Number of zero samples to pad the
input on the last dimension.
padding: (Union[int, Iterable[int], str]) If int or iterable of ints, number of zero samples to pad the
input on the last dimension. If str, only "same" is supported, which pads the input so the output size is preserved.
padding_mode: (str) Padding mode to use from {constant, reflection, replication}.
reflection not available for 3d.
stride: (Union[int, Iterable[int]) Stride size for computing output values.
dilation: (Union[int, Iterable[int]) Dilation rate for the kernel.
groups: (int) Number of groups for the convolution.
Returns:
(Tensor) Convolved tensor
"""

# Cast padding, stride & dilation to tuples.
n = signal.ndim - 2
padding_ = to_ntuple(padding, n=n)
stride_ = to_ntuple(stride, n=n)
dilation_ = to_ntuple(dilation, n=n)
if isinstance(padding, str):
if padding == "same":
if stride != 1 or dilation != 1:
raise ValueError("stride must be 1 for padding='same'.")
padding_ = [(k - 1) / 2 for k in kernel.shape[2:]]
else:
raise ValueError(f"Padding mode {padding} not supported.")
else:
padding_ = to_ntuple(padding, n=n)

# internal dilation offsets
offset = torch.zeros(1, 1, *dilation_, device=signal.device, dtype=signal.dtype)
Expand All @@ -93,8 +106,8 @@ def fft_conv(
# pad the kernel internally according to the dilation parameters
kernel = torch.kron(kernel, offset)[(slice(None), slice(None)) + cutoff]

# Pad the input signal & kernel tensors
signal_padding = [p for p in padding_[::-1] for _ in range(2)]
# Pad the input signal & kernel tensors (round to support even sized convolutions)
signal_padding = [r(p) for p in padding_[::-1] for r in (floor, ceil)]
signal = f.pad(signal, signal_padding, mode=padding_mode)

# Because PyTorch computes a *one-sided* FFT, we need the final dimension to
Expand Down Expand Up @@ -155,9 +168,14 @@ def __init__(
out_channels: (int) Number of channels in output tensors
kernel_size: (Union[int, Iterable[int]) Square radius of the kernel
padding: (Union[int, Iterable[int]) Number of zero samples to pad the
input on the last dimension.
input on the last dimension. If str, only "same" is supported, which pads the input so the output size is preserved.
padding_mode: (str) Padding mode to use from {constant, reflection, replication}.
reflection not available for 3d.
stride: (Union[int, Iterable[int]) Stride size for computing output values.
dilation: (Union[int, Iterable[int]) Dilation rate for the kernel.
groups: (int) Number of groups for the convolution.
bias: (bool) If True, includes bias, which is added after convolution
ndim: (int) Number of dimensions of the input tensor.
"""
super().__init__()
self.in_channels = in_channels
Expand Down
12 changes: 10 additions & 2 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
@pytest.mark.parametrize("out_channels", [2, 3])
@pytest.mark.parametrize("groups", [1, 2, 3])
@pytest.mark.parametrize("kernel_size", [2, 3])
@pytest.mark.parametrize("padding", [0, 1])
@pytest.mark.parametrize("padding", [0, 1, "same"])
@pytest.mark.parametrize("stride", [1, 2])
@pytest.mark.parametrize("dilation", [1, 2])
@pytest.mark.parametrize("bias", [True])
Expand All @@ -30,6 +30,10 @@ def test_fft_conv_functional(
ndim: int,
input_size: int,
):
if padding == "same" and (stride != 1 or dilation != 1):
# padding='same' is only supported for stride=1 and dilation=1
return

torch_conv = getattr(f, f"conv{ndim}d")
groups = _gcd(in_channels, _gcd(out_channels, groups))

Expand Down Expand Up @@ -70,7 +74,7 @@ def test_fft_conv_functional(
@pytest.mark.parametrize("out_channels", [2, 3])
@pytest.mark.parametrize("groups", [1, 2, 3])
@pytest.mark.parametrize("kernel_size", [2, 3])
@pytest.mark.parametrize("padding", [0, 1])
@pytest.mark.parametrize("padding", [0, 1, "same"])
@pytest.mark.parametrize("stride", [1, 2])
@pytest.mark.parametrize("dilation", [1, 2])
@pytest.mark.parametrize("bias", [True])
Expand All @@ -88,6 +92,10 @@ def test_fft_conv_backward_functional(
ndim: int,
input_size: int,
):
if padding == "same" and (stride != 1 or dilation != 1):
# padding='same' is only supported for stride=1 and dilation=1
return

torch_conv = getattr(f, f"conv{ndim}d")
groups = _gcd(in_channels, _gcd(out_channels, groups))

Expand Down
8 changes: 8 additions & 0 deletions tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ def test_fft_conv_module(
ndim: int,
input_size: int,
):
if padding == "same" and (stride != 1 or dilation != 1):
# padding='same' is only supported for stride=1 and dilation=1
return

torch_conv = getattr(f, f"conv{ndim}d")
groups = _gcd(in_channels, _gcd(out_channels, groups))
fft_conv_layer = _FFTConv(
Expand Down Expand Up @@ -85,6 +89,10 @@ def test_fft_conv_backward_module(
ndim: int,
input_size: int,
):
if padding == "same" and (stride != 1 or dilation != 1):
# padding='same' is only supported for stride=1 and dilation=1
return

torch_conv = getattr(f, f"conv{ndim}d")
groups = _gcd(in_channels, _gcd(out_channels, groups))
fft_conv_layer = _FFTConv(
Expand Down

0 comments on commit f13c3eb

Please sign in to comment.