
Commit 24ea2ac

feat(tidy3d): FXC-3961-faster-convolutions-for-tidy-3-d-plugins-autograd-filters
1 parent 989e807 commit 24ea2ac

File tree

3 files changed: +208 -17 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Unified run submission API: `web.run(...)` is now a container-aware wrapper that accepts a single simulation or arbitrarily nested containers (`list`, `tuple`, `dict` values) and returns results in the same shape.
 - `web.Batch(ComponentModeler)` and `web.Job(ComponentModeler)` native support
 - Simulation data of batch jobs are now automatically downloaded upon their individual completion in `Batch.run()`, avoiding waiting for the entire batch to reach completion.
+- Improved speed of autograd tracing for convolutions.
 
 ### Fixed
 - Ensured the legacy `Env` proxy mirrors `config.web` profile switches and preserves API URL.
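
The convolution speedup entry corresponds to the FFT-based implementation in the diff below. For intuition, a minimal standalone NumPy sketch (illustrative values, not tidy3d code) of the convolution theorem the new backend relies on: a linear convolution of lengths n and m is the inverse FFT of a pointwise product, with both inputs zero-padded to n + m - 1.

```python
import numpy as np

rng = np.random.default_rng(0)
x, k = rng.random(8), rng.random(3)

# Convolution theorem: O(n log n) FFTs instead of an O(n * m) sliding window.
n_fft = x.size + k.size - 1
via_fft = np.fft.ifft(np.fft.fft(x, n_fft) * np.fft.fft(k, n_fft)).real

assert np.allclose(via_fft, np.convolve(x, k, mode="full"))
```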

tests/test_plugins/autograd/test_functions.py

Lines changed: 52 additions & 3 deletions
@@ -117,12 +117,15 @@ def test_convolve_val(self, rng, mode, padding, ary_size, kernel_size, square_kernel
         """Test convolution values against SciPy for various modes, padding, array sizes, and kernel sizes."""
         x, k = self._ary_and_kernel(rng, ary_size, kernel_size, square_kernel)
 
-        if mode in ("full", "same"):
-            pad_widths = [(k // 2, k // 2) for k in k.shape]
+        if mode in ["same", "full"]:
+            if mode == "same":
+                pad_widths = [(k // 2, k // 2) for k in k.shape]
+            else:
+                pad_widths = [(k - 1, k - 1) for k in k.shape]
             x_padded = x
             for axis, pad_width in enumerate(pad_widths):
                 x_padded = pad(x_padded, pad_width, mode=padding, axis=axis)
-            conv_sp = convolve_sp(x_padded, k, mode="valid" if mode == "same" else mode)
+            conv_sp = convolve_sp(x_padded, k, mode="valid")
         else:
             conv_sp = convolve_sp(x, k, mode=mode)
 
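The reworked test branch makes the "full" case explicit: padding each side by k - 1 and then running a "valid" convolution reproduces a "full" convolution. A quick standalone NumPy check of that identity (illustrative values):

```python
import numpy as np

x, k = np.arange(5.0), np.ones(3)
padded = np.pad(x, k.size - 1)  # zero-pad both sides by k - 1

# "full" convolution == "valid" convolution of the padded input
assert np.allclose(np.convolve(x, k, mode="full"), np.convolve(padded, k, mode="valid"))
```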
@@ -168,6 +171,52 @@ def test_kernel_array_dimension_mismatch(self):
             convolve(self.array, kernel_mismatch)
 
 
+class TestConvolveAxes:
+    @pytest.mark.parametrize("mode", ["valid", "same", "full"])
+    @pytest.mark.parametrize("padding", ["constant", "edge"])
+    def test_convolve_axes_val(self, rng, mode, padding):
+        """Test convolution with explicit axes against NumPy implementations."""
+        array = rng.random((2, 5))
+        kernel = rng.random((3, 3))
+        axes = ([1], [1])
+
+        conv_td = convolve(array, kernel, padding=padding, mode=mode, axes=axes)
+
+        working_array = array
+        scipy_mode = mode
+        if mode in ("same", "full"):
+            pad_width = kernel.shape[1] // 2 if mode == "same" else kernel.shape[1] - 1
+            working_array = pad(array, (pad_width, pad_width), mode=padding, axis=1)
+            scipy_mode = "valid"
+
+        working_array_np = np.asarray(working_array)
+        kernel_np = np.asarray(kernel)
+        conv_length = np.convolve(working_array_np[0], kernel_np[0], mode=scipy_mode).shape[0]
+
+        expected = np.empty((array.shape[0], kernel.shape[0], conv_length))
+        for i in range(array.shape[0]):
+            for j in range(kernel.shape[0]):
+                expected[i, j] = np.convolve(
+                    working_array_np[i],
+                    kernel_np[j],
+                    mode=scipy_mode,
+                )
+
+        npt.assert_allclose(conv_td, expected, atol=1e-12)
+
+    def test_convolve_axes_grad(self, rng):
+        """Test gradients of convolution when specific axes are provided."""
+        array = rng.random((2, 5))
+        kernel = rng.random((3, 3))
+        check_grads(convolve, modes=["rev"], order=2)(
+            array,
+            kernel,
+            padding="constant",
+            mode="valid",
+            axes=([1], [1]),
+        )
+
+
 @pytest.mark.parametrize(
     "op,sp_op",
     [
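For reference, the shape semantics `TestConvolveAxes` pins down, as a standalone NumPy sketch (illustrative values, not tidy3d code): with `axes=([1], [1])`, every row of the array is convolved with every row of the kernel, and the non-convolved axes become leading batch axes of the result.

```python
import numpy as np

rng = np.random.default_rng(0)
array = rng.random((2, 5))   # axis 0: batch, axis 1: convolution
kernel = rng.random((3, 3))  # axis 0: batch, axis 1: convolution

# Every (array row, kernel row) pair is convolved -> shape (2, 3, length).
expected = np.stack(
    [[np.convolve(a, k, mode="valid") for k in kernel] for a in array]
)
print(expected.shape)  # (2, 3, 3)
```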
tidy3d/plugins/autograd/functions.py

Lines changed: 155 additions & 14 deletions
@@ -7,7 +7,7 @@
 import numpy as onp
 from autograd import jacobian
 from autograd.extend import defvjp, primitive
-from autograd.scipy.signal import convolve as convolve_ag
+from autograd.numpy.fft import fftn, ifftn
 from autograd.scipy.special import logsumexp
 from autograd.tracer import getval
 from numpy.lib.stride_tricks import sliding_window_view
@@ -37,6 +37,139 @@
 ]
 
 
+def _normalize_axes(
+    ndim_array: int,
+    ndim_kernel: int,
+    axes: Union[tuple[Iterable[int], Iterable[int]], None],
+) -> tuple[tuple[int, ...], tuple[int, ...]]:
+    """Normalize the axes specification for convolution."""
+
+    def _normalize_single_axis(ax: int, ndim: int, kind: str) -> int:
+        if not isinstance(ax, int):
+            raise TypeError(
+                f"Axis indices must be integers, got {ax!r} of type {type(ax).__name__}."
+            )
+
+        if not -ndim <= ax < ndim:
+            raise ValueError(f"Invalid axis {ax} for {kind} with ndim {ndim}.")
+        return ax + ndim if ax < 0 else ax
+
+    if axes is None:
+        if ndim_array != ndim_kernel:
+            raise ValueError(
+                "Kernel dimensions must match array dimensions when 'axes' is not provided, "
+                f"got array ndim {ndim_array} and kernel ndim {ndim_kernel}."
+            )
+        axes_array = tuple(range(ndim_array))
+        axes_kernel = tuple(range(ndim_kernel))
+        return axes_array, axes_kernel
+
+    if len(axes) != 2:
+        raise ValueError("'axes' must be a tuple of two iterable collections of axis indices.")
+
+    axes_array_raw, axes_kernel_raw = axes
+
+    axes_array = tuple(_normalize_single_axis(ax, ndim_array, "array") for ax in axes_array_raw)
+    axes_kernel = tuple(_normalize_single_axis(ax, ndim_kernel, "kernel") for ax in axes_kernel_raw)
+
+    if len(axes_array) != len(axes_kernel):
+        raise ValueError(
+            "The number of convolution axes for the array and kernel must be the same, "
+            f"got {len(axes_array)} and {len(axes_kernel)}."
+        )
+
+    if len(set(axes_array)) != len(axes_array) or len(set(axes_kernel)) != len(axes_kernel):
+        raise ValueError("Convolution axes must be unique for both the array and the kernel.")
+
+    return axes_array, axes_kernel
+
+
+def _fft_convolve_general(
+    array: NDArray,
+    kernel: NDArray,
+    axes_array: tuple[int, ...],
+    axes_kernel: tuple[int, ...],
+    mode: Literal["full", "valid"],
+) -> NDArray:
+    """Perform convolution using FFT along the specified axes."""
+
+    num_conv_axes = len(axes_array)
+
+    if num_conv_axes == 0:
+        array_shape = array.shape
+        kernel_shape = kernel.shape
+        result = np.multiply(
+            array.reshape(array_shape + (1,) * kernel.ndim),
+            kernel.reshape((1,) * array.ndim + kernel_shape),
+        )
+        return result.reshape(array_shape + kernel_shape)
+
+    ignore_axes_array = tuple(ax for ax in range(array.ndim) if ax not in axes_array)
+    ignore_axes_kernel = tuple(ax for ax in range(kernel.ndim) if ax not in axes_kernel)
+
+    new_order_array = ignore_axes_array + axes_array
+    new_order_kernel = ignore_axes_kernel + axes_kernel
+
+    array_reordered = np.transpose(array, new_order_array) if array.ndim else array
+    kernel_reordered = np.transpose(kernel, new_order_kernel) if kernel.ndim else kernel
+
+    num_batch_array = len(ignore_axes_array)
+    num_batch_kernel = len(ignore_axes_kernel)
+
+    array_conv_shape = array_reordered.shape[num_batch_array:]
+    kernel_conv_shape = kernel_reordered.shape[num_batch_kernel:]
+
+    if any(d <= 0 for d in array_conv_shape + kernel_conv_shape):
+        raise ValueError("Convolution dimensions must be positive; got zero-length axis.")
+
+    fft_axes = tuple(range(-num_conv_axes, 0))
+    fft_shape = tuple(
+        int(array_dim + kernel_dim - 1)
+        for array_dim, kernel_dim in zip(array_conv_shape, kernel_conv_shape)
+    )
+
+    array_fft = fftn(array_reordered, fft_shape, axes=fft_axes)
+    kernel_fft = fftn(kernel_reordered, fft_shape, axes=fft_axes)
+
+    if num_batch_kernel:
+        array_batch_shape = array_fft.shape[:num_batch_array]
+        conv_shape = array_fft.shape[num_batch_array:]
+        array_fft = np.reshape(
+            array_fft,
+            array_batch_shape + (1,) * num_batch_kernel + conv_shape,
+        )
+
+    if num_batch_array:
+        kernel_batch_shape = kernel_fft.shape[:num_batch_kernel]
+        conv_shape = kernel_fft.shape[num_batch_kernel:]
+        kernel_fft = np.reshape(
+            kernel_fft,
+            (1,) * num_batch_array + kernel_batch_shape + conv_shape,
+        )
+
+    full_result = ifftn(array_fft * kernel_fft, fft_shape, axes=fft_axes)
+
+    if mode == "full":
+        result = full_result
+    elif mode == "valid":
+        valid_slices = [slice(None)] * full_result.ndim
+        for axis_offset, (array_dim, kernel_dim) in enumerate(
+            zip(array_conv_shape, kernel_conv_shape)
+        ):
+            start = int(min(array_dim, kernel_dim) - 1)
+            length = int(abs(array_dim - kernel_dim) + 1)
+            axis = full_result.ndim - num_conv_axes + axis_offset
+            valid_slices[axis] = slice(start, start + length)
+        result = full_result[tuple(valid_slices)]
+    else:
+        raise ValueError(f"Unsupported convolution mode '{mode}'.")
+
+    if not np.iscomplexobj(array) and not np.iscomplexobj(kernel):
+        result = np.real(result)
+
+    return result
+
+
 def _get_pad_indices(
     n: int,
     pad_width: tuple[int, int],
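
A standalone sanity check of the slicing logic in `_fft_convolve_general` (plain NumPy/SciPy, not part of the commit): along each convolved axis, the "valid" region of the FFT-based "full" result starts at min(n, m) - 1 and spans |n - m| + 1 samples.

```python
import numpy as np
from scipy.signal import fftconvolve

rng = np.random.default_rng(1)
x = rng.random((8, 8))
k = rng.random((3, 5))

# 2-D "full" convolution via FFTs zero-padded to n + m - 1 per axis.
s = tuple(xd + kd - 1 for xd, kd in zip(x.shape, k.shape))
full = np.fft.ifftn(np.fft.fftn(x, s) * np.fft.fftn(k, s)).real

# "valid" region per axis: start = min(n, m) - 1, length = |n - m| + 1.
slices = tuple(
    slice(min(xd, kd) - 1, min(xd, kd) - 1 + abs(xd - kd) + 1)
    for xd, kd in zip(x.shape, k.shape)
)
assert np.allclose(full, fftconvolve(x, k, mode="full"))
assert np.allclose(full[slices], fftconvolve(x, k, mode="valid"))
```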
@@ -180,19 +313,27 @@ def convolve(
     if any(k % 2 == 0 for k in kernel.shape):
         raise ValueError(f"All kernel dimensions must be odd, got {kernel.shape}.")
 
-    if kernel.ndim != array.ndim and axes is None:
-        raise ValueError(
-            f"Kernel dimensions must match array dimensions, got kernel {kernel.shape} and array {array.shape}."
-        )
-
-    if mode in ("same", "full"):
-        kernel_dims = kernel.shape if axes is None else [kernel.shape[d] for d in axes[1]]
-        pad_widths = [(ks // 2, ks // 2) for ks in kernel_dims]
-        for axis, pad_width in enumerate(pad_widths):
-            array = pad(array, pad_width, mode=padding, axis=axis)
-        mode = "valid" if mode == "same" else mode
-
-    return convolve_ag(array, kernel, axes=axes, mode=mode)
+    axes_array, axes_kernel = _normalize_axes(array.ndim, kernel.ndim, axes)
+
+    working_array = array
+    effective_mode = mode
+
+    if mode in ["same", "full"]:
+        for ax_array, ax_kernel in zip(axes_array, axes_kernel):
+            if mode == "same" and kernel.shape[ax_kernel] % 2 == 0:
+                raise ValueError(
+                    f"Even-sized kernel along axis {ax_kernel} not supported for 'same' mode."
+                )
+            pad_width = (
+                kernel.shape[ax_kernel] // 2 if mode == "same" else kernel.shape[ax_kernel] - 1
+            )
+            if pad_width > 0:
+                working_array = pad(
+                    working_array, (pad_width, pad_width), mode=padding, axis=ax_array
+                )
+        effective_mode = "valid"
+
+    return _fft_convolve_general(working_array, kernel, axes_array, axes_kernel, effective_mode)
 
 
 def _get_footprint(size, structure, maxval):
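
End-to-end, the public `convolve` keeps its signature; only the backend changed from `autograd.scipy.signal.convolve` to the FFT path. A usage sketch mirroring the test suite (assuming an environment with tidy3d and autograd installed; the box-filter kernel is illustrative):

```python
import numpy as onp
from autograd.test_util import check_grads

from tidy3d.plugins.autograd.functions import convolve

array = onp.random.default_rng(2).random((4, 4, 4))
kernel = onp.ones((3, 3, 3)) / 27.0  # odd-sized box filter

# "same" mode pads by k // 2 per axis, so the output shape matches the input.
out = convolve(array, kernel, padding="edge", mode="same")
assert out.shape == array.shape

# Reverse-mode gradients flow through autograd's fftn/ifftn primitives.
check_grads(convolve, modes=["rev"], order=2)(array, kernel, padding="edge", mode="same")
```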
