Skip to content

Commit 4df0cd5

Browse files
feat(tidy3d): FXC-3961-faster-convolutions-for-tidy-3-d-plugins-autograd-filters
1 parent e73fbb4 commit 4df0cd5

File tree

3 files changed

+184
-13
lines changed

3 files changed

+184
-13
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
5353
- Unified run submission API: `web.run(...)` is now a container-aware wrapper that accepts a single simulation or arbitrarily nested containers (`list`, `tuple`, `dict` values) and returns results in the same shape.
5454
- `web.Batch(ComponentModeler)` and `web.Job(ComponentModeler)` native support
5555
- Simulation data of batch jobs are now automatically downloaded upon their individual completion in `Batch.run()`, avoiding waiting for the entire batch to reach completion.
56-
56+
- Improved speed of autograd tracing for convolutions.
57+
-
5758
### Fixed
5859
- Ensured the legacy `Env` proxy mirrors `config.web` profile switches and preserves API URL.
5960
- More robust `Sellmeier` and `Debye` material model, and prevent very large pole parameters in `PoleResidue` material model.

tests/test_plugins/autograd/test_functions.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,52 @@ def test_kernel_array_dimension_mismatch(self):
168168
convolve(self.array, kernel_mismatch)
169169

170170

171+
class TestConvolveAxes:
    """Checks for ``convolve`` when the convolution axes are given explicitly."""

    @pytest.mark.parametrize("mode", ["valid", "same", "full"])
    @pytest.mark.parametrize("padding", ["constant", "edge"])
    def test_convolve_axes_val(self, rng, mode, padding):
        """Compare axis-restricted convolution values with a row-by-row NumPy reference."""
        array = rng.random((2, 5))
        kernel = rng.random((3, 3))

        conv_td = convolve(array, kernel, padding=padding, mode=mode, axes=([1], [1]))

        # Build the reference input: for "same"/"full" the array is padded along
        # axis 1 by half the kernel extent, and "same" is then evaluated as "valid".
        if mode in ("same", "full"):
            half_width = kernel.shape[1] // 2
            reference_input = pad(array, (half_width, half_width), mode=padding, axis=1)
            reference_mode = "valid" if mode == "same" else mode
        else:
            reference_input = array
            reference_mode = mode

        rows = np.asarray(reference_input)
        taps = np.asarray(kernel)

        # One 1D convolution per (array row, kernel row) pair; axis 0 of both
        # inputs is not convolved, so both show up as leading output axes.
        out_len = np.convolve(rows[0], taps[0], mode=reference_mode).shape[0]
        expected = np.empty((array.shape[0], kernel.shape[0], out_len))
        for i, row in enumerate(rows):
            for j, tap in enumerate(taps):
                expected[i, j] = np.convolve(row, tap, mode=reference_mode)

        npt.assert_allclose(conv_td, expected, atol=1e-12)

    def test_convolve_axes_grad(self, rng):
        """Check reverse-mode gradients (up to second order) with explicit axes."""
        array = rng.random((2, 5))
        kernel = rng.random((3, 3))

        grad_checker = check_grads(convolve, modes=["rev"], order=2)
        grad_checker(
            array,
            kernel,
            padding="constant",
            mode="valid",
            axes=([1], [1]),
        )
215+
216+
171217
@pytest.mark.parametrize(
172218
"op,sp_op",
173219
[

tidy3d/plugins/autograd/functions.py

Lines changed: 136 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as onp
88
from autograd import jacobian
99
from autograd.extend import defvjp, primitive
10-
from autograd.scipy.signal import convolve as convolve_ag
10+
from autograd.numpy.fft import fftn, ifftn
1111
from autograd.scipy.special import logsumexp
1212
from autograd.tracer import getval
1313
from numpy.lib.stride_tricks import sliding_window_view
@@ -37,6 +37,128 @@
3737
]
3838

3939

40+
def _normalize_axes(
41+
ndim_array: int,
42+
ndim_kernel: int,
43+
axes: Union[tuple[Iterable[int], Iterable[int]], None],
44+
) -> tuple[tuple[int, ...], tuple[int, ...]]:
45+
"""Normalize the axes specification for convolution."""
46+
47+
if axes is None:
48+
if ndim_array != ndim_kernel:
49+
raise ValueError(
50+
"Kernel dimensions must match array dimensions when 'axes' is not provided, "
51+
f"got array ndim {ndim_array} and kernel ndim {ndim_kernel}."
52+
)
53+
axes_array = tuple(range(ndim_array))
54+
axes_kernel = tuple(range(ndim_kernel))
55+
return axes_array, axes_kernel
56+
57+
if len(axes) != 2:
58+
raise ValueError("'axes' must be a tuple of two iterable collections of axis indices.")
59+
60+
axes_array_raw, axes_kernel_raw = axes
61+
62+
axes_array = tuple((ax + ndim_array) % ndim_array for ax in axes_array_raw)
63+
axes_kernel = tuple((ax + ndim_kernel) % ndim_kernel for ax in axes_kernel_raw)
64+
65+
if len(axes_array) != len(axes_kernel):
66+
raise ValueError(
67+
"The number of convolution axes for the array and kernel must be the same, "
68+
f"got {len(axes_array)} and {len(axes_kernel)}."
69+
)
70+
71+
if len(set(axes_array)) != len(axes_array) or len(set(axes_kernel)) != len(axes_kernel):
72+
raise ValueError("Convolution axes must be unique for both the array and the kernel.")
73+
74+
if any(ax < 0 or ax >= ndim_array for ax in axes_array):
75+
raise ValueError(
76+
f"Array axes out of bounds for array with {ndim_array} dimensions: {axes_array}."
77+
)
78+
79+
if any(ax < 0 or ax >= ndim_kernel for ax in axes_kernel):
80+
raise ValueError(
81+
f"Kernel axes out of bounds for kernel with {ndim_kernel} dimensions: {axes_kernel}."
82+
)
83+
84+
return axes_array, axes_kernel
85+
86+
87+
def _fft_convolve_general(
88+
array: NDArray,
89+
kernel: NDArray,
90+
axes_array: tuple[int, ...],
91+
axes_kernel: tuple[int, ...],
92+
mode: Literal["full", "valid"],
93+
) -> NDArray:
94+
"""Perform convolution using FFT along the specified axes."""
95+
96+
num_conv_axes = len(axes_array)
97+
98+
if num_conv_axes == 0:
99+
array_shape = array.shape
100+
kernel_shape = kernel.shape
101+
result = np.multiply(
102+
array.reshape(array_shape + (1,) * kernel.ndim),
103+
kernel.reshape((1,) * array.ndim + kernel_shape),
104+
)
105+
return result.reshape(array_shape + kernel_shape)
106+
107+
ignore_axes_array = tuple(ax for ax in range(array.ndim) if ax not in axes_array)
108+
ignore_axes_kernel = tuple(ax for ax in range(kernel.ndim) if ax not in axes_kernel)
109+
110+
new_order_array = ignore_axes_array + axes_array
111+
new_order_kernel = ignore_axes_kernel + axes_kernel
112+
113+
array_reordered = np.transpose(array, new_order_array) if array.ndim else array
114+
kernel_reordered = np.transpose(kernel, new_order_kernel) if kernel.ndim else kernel
115+
116+
num_batch_array = len(ignore_axes_array)
117+
num_batch_kernel = len(ignore_axes_kernel)
118+
119+
array_batch_shape = array_reordered.shape[:num_batch_array]
120+
kernel_batch_shape = kernel_reordered.shape[:num_batch_kernel]
121+
122+
array_conv_shape = array_reordered.shape[num_batch_array:]
123+
kernel_conv_shape = kernel_reordered.shape[num_batch_kernel:]
124+
125+
array_expand_shape = array_batch_shape + (1,) * num_batch_kernel + array_conv_shape
126+
kernel_expand_shape = (1,) * num_batch_array + kernel_batch_shape + kernel_conv_shape
127+
128+
array_expanded = np.reshape(array_reordered, array_expand_shape)
129+
kernel_expanded = np.reshape(kernel_reordered, kernel_expand_shape)
130+
131+
fft_axes = tuple(range(-num_conv_axes, 0))
132+
fft_shape = tuple(
133+
int(array_dim + kernel_dim - 1)
134+
for array_dim, kernel_dim in zip(array_conv_shape, kernel_conv_shape)
135+
)
136+
137+
array_fft = fftn(array_expanded, fft_shape, axes=fft_axes)
138+
kernel_fft = fftn(kernel_expanded, fft_shape, axes=fft_axes)
139+
full_result = ifftn(array_fft * kernel_fft, fft_shape, axes=fft_axes)
140+
141+
if mode == "full":
142+
result = full_result
143+
elif mode == "valid":
144+
valid_slices = [slice(None)] * full_result.ndim
145+
for axis_offset, (array_dim, kernel_dim) in enumerate(
146+
zip(array_conv_shape, kernel_conv_shape)
147+
):
148+
start = int(min(array_dim, kernel_dim) - 1)
149+
length = int(abs(array_dim - kernel_dim) + 1)
150+
axis = full_result.ndim - num_conv_axes + axis_offset
151+
valid_slices[axis] = slice(start, start + length)
152+
result = full_result[tuple(valid_slices)]
153+
else:
154+
raise ValueError(f"Unsupported convolution mode '{mode}'.")
155+
156+
if not np.iscomplexobj(array) and not np.iscomplexobj(kernel):
157+
result = np.real(result)
158+
159+
return result
160+
161+
40162
def _get_pad_indices(
41163
n: int,
42164
pad_width: tuple[int, int],
@@ -189,19 +311,21 @@ def convolve(
189311
if any(k % 2 == 0 for k in kernel.shape):
190312
raise ValueError(f"All kernel dimensions must be odd, got {kernel.shape}.")
191313

192-
if kernel.ndim != array.ndim and axes is None:
193-
raise ValueError(
194-
f"Kernel dimensions must match array dimensions, got kernel {kernel.shape} and array {array.shape}."
195-
)
314+
axes_array, axes_kernel = _normalize_axes(array.ndim, kernel.ndim, axes)
196315

197-
if mode in ("same", "full"):
198-
kernel_dims = kernel.shape if axes is None else [kernel.shape[d] for d in axes[1]]
199-
pad_widths = [(ks // 2, ks // 2) for ks in kernel_dims]
200-
for axis, pad_width in enumerate(pad_widths):
201-
array = pad(array, pad_width, mode=padding, axis=axis)
202-
mode = "valid" if mode == "same" else mode
316+
working_array = array
317+
effective_mode = mode
203318

204-
return convolve_ag(array, kernel, axes=axes, mode=mode)
319+
if mode in ("same", "full"):
320+
for ax_array, ax_kernel in zip(axes_array, axes_kernel):
321+
pad_width = kernel.shape[ax_kernel] // 2
322+
if pad_width > 0:
323+
working_array = pad(
324+
working_array, (pad_width, pad_width), mode=padding, axis=ax_array
325+
)
326+
effective_mode = "valid" if mode == "same" else mode
327+
328+
return _fft_convolve_general(working_array, kernel, axes_array, axes_kernel, effective_mode)
205329

206330

207331
def _get_footprint(size, structure, maxval):

0 commit comments

Comments
 (0)