Skip to content

Commit 76687d6

Browse files
feat(tidy3d): FXC-3961-faster-convolutions-for-tidy-3-d-plugins-autograd-filters
1 parent 12998d9 commit 76687d6

File tree

3 files changed

+276
-29
lines changed

3 files changed

+276
-29
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
5858
- `web.Batch(ComponentModeler)` and `web.Job(ComponentModeler)` native support
5959
- Simulation data of batch jobs are now automatically downloaded upon their individual completion in `Batch.run()`, avoiding waiting for the entire batch to reach completion.
6060
- Port names in `ModalComponentModeler` and `TerminalComponentModeler` can no longer include the `@` symbol.
61+
- Improved speed of convolutions for large inputs.
6162

6263
### Fixed
6364
- Ensured the legacy `Env` proxy mirrors `config.web` profile switches and preserves API URL.

tests/test_plugins/autograd/test_functions.py

Lines changed: 121 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
threshold,
2929
trapz,
3030
)
31+
from tidy3d.plugins.autograd.functions import _normalize_axes
3132
from tidy3d.plugins.autograd.types import PaddingType
3233

3334
_mode_to_scipy = {
@@ -38,6 +39,15 @@
3839
"wrap": "wrap",
3940
}
4041

42+
CONV_MODES = ["full", "same", "valid"]
43+
44+
_CONVOLVE_AXES_CASES = [
45+
([0], [0]),
46+
([1], [1]),
47+
([1], [0]),
48+
([-1], [-1]),
49+
]
50+
4151

4252
@pytest.mark.parametrize("mode", PaddingType.__args__)
4353
@pytest.mark.parametrize("size", [3, 4, (3, 3), (4, 4), (3, 4), (3, 3, 3), (4, 4, 4), (3, 4, 5)])
@@ -94,7 +104,7 @@ def test_negative_axis_out_of_range(self):
94104
pad(self.array, (1, 1), axis=-3)
95105

96106

97-
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
107+
@pytest.mark.parametrize("mode", CONV_MODES)
98108
@pytest.mark.parametrize("padding", PaddingType.__args__)
99109
@pytest.mark.parametrize(
100110
"ary_size", [7, 8, (7, 7), (8, 8), (7, 8), (7, 7, 7), (8, 8, 8), (7, 8, 9)]
@@ -117,22 +127,10 @@ def test_convolve_val(self, rng, mode, padding, ary_size, kernel_size, square_ke
117127
"""Test convolution values against SciPy for various modes, padding, array sizes, and kernel sizes."""
118128
x, k = self._ary_and_kernel(rng, ary_size, kernel_size, square_kernel)
119129

120-
if mode in ("full", "same"):
121-
pad_widths = [(k // 2, k // 2) for k in k.shape]
122-
x_padded = x
123-
for axis, pad_width in enumerate(pad_widths):
124-
x_padded = pad(x_padded, pad_width, mode=padding, axis=axis)
125-
conv_sp = convolve_sp(x_padded, k, mode="valid" if mode == "same" else mode)
126-
else:
127-
conv_sp = convolve_sp(x, k, mode=mode)
128-
129130
conv_td = convolve(x, k, padding=padding, mode=mode)
131+
conv_sp = _reference_convolution(x, k, mode, padding, axes=None)
130132

131-
npt.assert_allclose(
132-
conv_td,
133-
conv_sp,
134-
atol=1e-12, # scipy's "full" somehow is not zero at the edges...
135-
)
133+
npt.assert_allclose(conv_td, conv_sp, atol=1e-12)
136134

137135
def test_convolve_grad(self, rng, mode, padding, ary_size, kernel_size, square_kernel):
138136
"""Test gradients of convolution function for various modes, padding, array sizes, and kernel sizes."""
@@ -168,6 +166,114 @@ def test_kernel_array_dimension_mismatch(self):
168166
convolve(self.array, kernel_mismatch)
169167

170168

169+
def _reference_convolve_with_axes(array, kernel, axes_array, axes_kernel, mode):
170+
"""Construct a SciPy reference for convolutions with explicit axes."""
171+
172+
array_batch_axes = tuple(ax for ax in range(array.ndim) if ax not in axes_array)
173+
kernel_batch_axes = tuple(ax for ax in range(kernel.ndim) if ax not in axes_kernel)
174+
175+
array_perm = array_batch_axes + axes_array
176+
kernel_perm = kernel_batch_axes + axes_kernel
177+
178+
array_reordered = np.transpose(array, array_perm)
179+
kernel_reordered = np.transpose(kernel, kernel_perm)
180+
181+
len_array_batch = len(array_batch_axes)
182+
len_kernel_batch = len(kernel_batch_axes)
183+
184+
array_batch_shape = array_reordered.shape[:len_array_batch]
185+
kernel_batch_shape = kernel_reordered.shape[:len_kernel_batch]
186+
187+
sample_conv = convolve_sp(
188+
array_reordered[(0,) * len_array_batch],
189+
kernel_reordered[(0,) * len_kernel_batch],
190+
mode=mode,
191+
)
192+
conv_shape = sample_conv.shape
193+
194+
expected = np.empty(array_batch_shape + kernel_batch_shape + conv_shape)
195+
196+
for idx_array in np.ndindex(array_batch_shape):
197+
array_slice = array_reordered[idx_array]
198+
for idx_kernel in np.ndindex(kernel_batch_shape):
199+
kernel_slice = kernel_reordered[idx_kernel]
200+
expected[idx_array + idx_kernel] = convolve_sp(array_slice, kernel_slice, mode=mode)
201+
202+
return expected
203+
204+
205+
def _prepare_reference_inputs(array, kernel, mode, padding, axes):
206+
"""Apply padding logic to match tidy3d's convolution before building a reference."""
207+
208+
axes_array, axes_kernel = _normalize_axes(array.ndim, kernel.ndim, axes)
209+
210+
working_array = array
211+
scipy_mode = mode
212+
213+
if mode in ("same", "full"):
214+
for ax_array, ax_kernel in zip(axes_array, axes_kernel):
215+
pad_width = (
216+
kernel.shape[ax_kernel] // 2 if mode == "same" else kernel.shape[ax_kernel] - 1
217+
)
218+
if pad_width > 0:
219+
working_array = pad(
220+
working_array, (pad_width, pad_width), mode=padding, axis=ax_array
221+
)
222+
scipy_mode = "valid"
223+
224+
working_array_np = np.asarray(working_array)
225+
kernel_np = np.asarray(kernel)
226+
227+
return working_array_np, kernel_np, axes_array, axes_kernel, scipy_mode
228+
229+
230+
def _reference_convolution(array, kernel, mode, padding, axes):
231+
"""Full reference that mimics tidy3d padding rules before SciPy convolution."""
232+
233+
working_array_np, kernel_np, axes_array, axes_kernel, scipy_mode = _prepare_reference_inputs(
234+
array,
235+
kernel,
236+
mode,
237+
padding,
238+
axes,
239+
)
240+
241+
return _reference_convolve_with_axes(
242+
working_array_np,
243+
kernel_np,
244+
axes_array,
245+
axes_kernel,
246+
scipy_mode,
247+
)
248+
249+
250+
@pytest.mark.parametrize("mode", CONV_MODES)
251+
@pytest.mark.parametrize("padding", PaddingType.__args__)
252+
@pytest.mark.parametrize("axes", _CONVOLVE_AXES_CASES)
253+
class TestConvolveAxes:
254+
def test_convolve_axes_val(self, rng, mode, padding, axes):
255+
"""Test convolution with explicit axes against NumPy implementations."""
256+
array = rng.random((2, 5))
257+
kernel = rng.random((3, 3))
258+
259+
conv_td = convolve(array, kernel, padding=padding, mode=mode, axes=axes)
260+
expected = _reference_convolution(array, kernel, mode, padding, axes)
261+
262+
npt.assert_allclose(conv_td, expected, atol=1e-12)
263+
264+
def test_convolve_axes_grad(self, rng, axes, mode, padding):
265+
"""Test gradients of convolution when specific axes are provided."""
266+
array = rng.random((2, 5))
267+
kernel = rng.random((3, 3))
268+
check_grads(convolve, modes=["rev"], order=2)(
269+
array,
270+
kernel,
271+
padding=padding,
272+
mode=mode,
273+
axes=axes,
274+
)
275+
276+
171277
@pytest.mark.parametrize(
172278
"op,sp_op",
173279
[

tidy3d/plugins/autograd/functions.py

Lines changed: 154 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
from __future__ import annotations
22

33
from collections.abc import Iterable
4-
from typing import Callable, Literal, Union
4+
from typing import Callable, Literal, SupportsInt, Union
55

66
import autograd.numpy as np
77
import numpy as onp
88
from autograd import jacobian
99
from autograd.extend import defvjp, primitive
10-
from autograd.scipy.signal import convolve as convolve_ag
10+
from autograd.numpy.fft import fftn, ifftn
1111
from autograd.scipy.special import logsumexp
1212
from autograd.tracer import getval
13+
from numpy.fft import irfftn, rfftn
1314
from numpy.lib.stride_tricks import sliding_window_view
1415
from numpy.typing import NDArray
16+
from scipy.fft import next_fast_len
1517

1618
from tidy3d.components.autograd.functions import add_at, interpn, trapz
1719

@@ -37,6 +39,140 @@
3739
]
3840

3941

42+
def _normalize_axes(
43+
ndim_array: int,
44+
ndim_kernel: int,
45+
axes: Union[tuple[Iterable[SupportsInt], Iterable[SupportsInt]], None],
46+
) -> tuple[tuple[int, ...], tuple[int, ...]]:
47+
"""Normalize the axes specification for convolution."""
48+
49+
def _normalize_single_axis(ax: SupportsInt, ndim: int, kind: str) -> int:
50+
if not isinstance(ax, int):
51+
try:
52+
ax = int(ax)
53+
except Exception as e:
54+
raise TypeError(f"Axis {ax!r} could not be converted to an integer.") from e
55+
56+
if not -ndim <= ax < ndim:
57+
raise ValueError(f"Invalid axis {ax} for {kind} with ndim {ndim}.")
58+
return ax + ndim if ax < 0 else ax
59+
60+
if axes is None:
61+
if ndim_array != ndim_kernel:
62+
raise ValueError(
63+
"Kernel dimensions must match array dimensions when 'axes' is not provided, "
64+
f"got array ndim {ndim_array} and kernel ndim {ndim_kernel}."
65+
)
66+
axes_array = tuple(range(ndim_array))
67+
axes_kernel = tuple(range(ndim_kernel))
68+
return axes_array, axes_kernel
69+
70+
if len(axes) != 2:
71+
raise ValueError("'axes' must be a tuple of two iterable collections of axis indices.")
72+
73+
axes_array_raw, axes_kernel_raw = axes
74+
75+
axes_array = tuple(_normalize_single_axis(ax, ndim_array, "array") for ax in axes_array_raw)
76+
axes_kernel = tuple(_normalize_single_axis(ax, ndim_kernel, "kernel") for ax in axes_kernel_raw)
77+
78+
if len(axes_array) != len(axes_kernel):
79+
raise ValueError(
80+
"The number of convolution axes for the array and kernel must be the same, "
81+
f"got {len(axes_array)} and {len(axes_kernel)}."
82+
)
83+
84+
if len(set(axes_array)) != len(axes_array) or len(set(axes_kernel)) != len(axes_kernel):
85+
raise ValueError("Convolution axes must be unique for both the array and the kernel.")
86+
87+
return axes_array, axes_kernel
88+
89+
90+
def _fft_convolve_general(
91+
array: NDArray,
92+
kernel: NDArray,
93+
axes_array: tuple[int, ...],
94+
axes_kernel: tuple[int, ...],
95+
mode: Literal["full", "valid"],
96+
) -> NDArray:
97+
"""Perform convolution using FFT along the specified axes."""
98+
99+
num_conv_axes = len(axes_array)
100+
101+
if num_conv_axes == 0:
102+
array_shape = array.shape
103+
kernel_shape = kernel.shape
104+
result = np.multiply(
105+
array.reshape(array_shape + (1,) * kernel.ndim),
106+
kernel.reshape((1,) * array.ndim + kernel_shape),
107+
)
108+
return result.reshape(array_shape + kernel_shape)
109+
110+
ignore_axes_array = tuple(ax for ax in range(array.ndim) if ax not in axes_array)
111+
ignore_axes_kernel = tuple(ax for ax in range(kernel.ndim) if ax not in axes_kernel)
112+
113+
new_order_array = ignore_axes_array + axes_array
114+
new_order_kernel = ignore_axes_kernel + axes_kernel
115+
116+
array_reordered = np.transpose(array, new_order_array) if array.ndim else array
117+
kernel_reordered = np.transpose(kernel, new_order_kernel) if kernel.ndim else kernel
118+
119+
num_batch_array = len(ignore_axes_array)
120+
num_batch_kernel = len(ignore_axes_kernel)
121+
122+
array_conv_shape = array_reordered.shape[num_batch_array:]
123+
kernel_conv_shape = kernel_reordered.shape[num_batch_kernel:]
124+
125+
if any(d <= 0 for d in array_conv_shape + kernel_conv_shape):
126+
raise ValueError("Convolution dimensions must be positive; got zero-length axis.")
127+
128+
fft_axes = tuple(range(-num_conv_axes, 0))
129+
fft_shape = [next_fast_len(n + k - 1) for n, k in zip(array_conv_shape, kernel_conv_shape)]
130+
use_real_fft = fft_shape[-1] % 2 == 0 # only applicable in this case
131+
132+
fft_fun = rfftn if use_real_fft else fftn
133+
array_fft = fft_fun(array_reordered, fft_shape, axes=fft_axes)
134+
kernel_fft = fft_fun(kernel_reordered, fft_shape, axes=fft_axes)
135+
136+
if num_batch_kernel:
137+
array_batch_shape = array_fft.shape[:num_batch_array]
138+
conv_shape = array_fft.shape[num_batch_array:]
139+
array_fft = np.reshape(
140+
array_fft,
141+
array_batch_shape + (1,) * num_batch_kernel + conv_shape,
142+
)
143+
144+
if num_batch_array:
145+
kernel_batch_shape = kernel_fft.shape[:num_batch_kernel]
146+
conv_shape = kernel_fft.shape[num_batch_kernel:]
147+
kernel_fft = np.reshape(
148+
kernel_fft,
149+
(1,) * num_batch_array + kernel_batch_shape + conv_shape,
150+
)
151+
use_real_fft = fft_shape[-1] % 2 == 0
152+
153+
product = array_fft * kernel_fft
154+
155+
ifft_fun = irfftn if use_real_fft else ifftn
156+
full_result = ifft_fun(product, fft_shape, axes=fft_axes)
157+
158+
if mode == "full":
159+
result = full_result
160+
elif mode == "valid":
161+
valid_slices = [slice(None)] * full_result.ndim
162+
for axis_offset, (array_dim, kernel_dim) in enumerate(
163+
zip(array_conv_shape, kernel_conv_shape)
164+
):
165+
start = int(min(array_dim, kernel_dim) - 1)
166+
length = int(abs(array_dim - kernel_dim) + 1)
167+
axis = full_result.ndim - num_conv_axes + axis_offset
168+
valid_slices[axis] = slice(start, start + length)
169+
result = full_result[tuple(valid_slices)]
170+
else:
171+
raise ValueError(f"Unsupported convolution mode '{mode}'.")
172+
173+
return np.real(result)
174+
175+
40176
def _get_pad_indices(
41177
n: int,
42178
pad_width: tuple[int, int],
@@ -148,7 +284,7 @@ def convolve(
148284
kernel: NDArray,
149285
*,
150286
padding: PaddingType = "constant",
151-
axes: Union[tuple[list[int], list[int]], None] = None,
287+
axes: Union[tuple[list[SupportsInt], list[SupportsInt]], None] = None,
152288
mode: Literal["full", "valid", "same"] = "same",
153289
) -> NDArray:
154290
"""Convolve an array with a given kernel.
@@ -180,19 +316,23 @@ def convolve(
180316
if any(k % 2 == 0 for k in kernel.shape):
181317
raise ValueError(f"All kernel dimensions must be odd, got {kernel.shape}.")
182318

183-
if kernel.ndim != array.ndim and axes is None:
184-
raise ValueError(
185-
f"Kernel dimensions must match array dimensions, got kernel {kernel.shape} and array {array.shape}."
186-
)
319+
axes_array, axes_kernel = _normalize_axes(array.ndim, kernel.ndim, axes)
320+
321+
working_array = array
322+
effective_mode = mode
187323

188-
if mode in ("same", "full"):
189-
kernel_dims = kernel.shape if axes is None else [kernel.shape[d] for d in axes[1]]
190-
pad_widths = [(ks // 2, ks // 2) for ks in kernel_dims]
191-
for axis, pad_width in enumerate(pad_widths):
192-
array = pad(array, pad_width, mode=padding, axis=axis)
193-
mode = "valid" if mode == "same" else mode
324+
if mode in ["same", "full"]:
325+
for ax_array, ax_kernel in zip(axes_array, axes_kernel):
326+
pad_width = (
327+
kernel.shape[ax_kernel] // 2 if mode == "same" else kernel.shape[ax_kernel] - 1
328+
)
329+
if pad_width > 0:
330+
working_array = pad(
331+
working_array, (pad_width, pad_width), mode=padding, axis=ax_array
332+
)
333+
effective_mode = "valid"
194334

195-
return convolve_ag(array, kernel, axes=axes, mode=mode)
335+
return _fft_convolve_general(working_array, kernel, axes_array, axes_kernel, effective_mode)
196336

197337

198338
def _get_footprint(size, structure, maxval):

0 commit comments

Comments
 (0)