|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | from collections.abc import Iterable |
4 | | -from typing import Callable, Literal, Union |
| 4 | +from typing import Callable, Literal, SupportsInt, Union |
5 | 5 |
|
6 | 6 | import autograd.numpy as np |
7 | 7 | import numpy as onp |
8 | 8 | from autograd import jacobian |
9 | 9 | from autograd.extend import defvjp, primitive |
10 | | -from autograd.scipy.signal import convolve as convolve_ag |
| 10 | +from autograd.numpy.fft import fftn, ifftn |
11 | 11 | from autograd.scipy.special import logsumexp |
12 | 12 | from autograd.tracer import getval |
| 13 | +from numpy.fft import irfftn, rfftn |
13 | 14 | from numpy.lib.stride_tricks import sliding_window_view |
14 | 15 | from numpy.typing import NDArray |
| 16 | +from scipy.fft import next_fast_len |
15 | 17 |
|
16 | 18 | from tidy3d.components.autograd.functions import add_at, interpn, trapz |
17 | 19 |
|
|
37 | 39 | ] |
38 | 40 |
|
39 | 41 |
|
| 42 | +def _normalize_axes( |
| 43 | + ndim_array: int, |
| 44 | + ndim_kernel: int, |
| 45 | + axes: Union[tuple[Iterable[SupportsInt], Iterable[SupportsInt]], None], |
| 46 | +) -> tuple[tuple[int, ...], tuple[int, ...]]: |
| 47 | + """Normalize the axes specification for convolution.""" |
| 48 | + |
| 49 | + def _normalize_single_axis(ax: SupportsInt, ndim: int, kind: str) -> int: |
| 50 | + if not isinstance(ax, int): |
| 51 | + try: |
| 52 | + ax = int(ax) |
| 53 | + except Exception as e: |
| 54 | + raise TypeError(f"Axis {ax!r} could not be converted to an integer.") from e |
| 55 | + |
| 56 | + if not -ndim <= ax < ndim: |
| 57 | + raise ValueError(f"Invalid axis {ax} for {kind} with ndim {ndim}.") |
| 58 | + return ax + ndim if ax < 0 else ax |
| 59 | + |
| 60 | + if axes is None: |
| 61 | + if ndim_array != ndim_kernel: |
| 62 | + raise ValueError( |
| 63 | + "Kernel dimensions must match array dimensions when 'axes' is not provided, " |
| 64 | + f"got array ndim {ndim_array} and kernel ndim {ndim_kernel}." |
| 65 | + ) |
| 66 | + axes_array = tuple(range(ndim_array)) |
| 67 | + axes_kernel = tuple(range(ndim_kernel)) |
| 68 | + return axes_array, axes_kernel |
| 69 | + |
| 70 | + if len(axes) != 2: |
| 71 | + raise ValueError("'axes' must be a tuple of two iterable collections of axis indices.") |
| 72 | + |
| 73 | + axes_array_raw, axes_kernel_raw = axes |
| 74 | + |
| 75 | + axes_array = tuple(_normalize_single_axis(ax, ndim_array, "array") for ax in axes_array_raw) |
| 76 | + axes_kernel = tuple(_normalize_single_axis(ax, ndim_kernel, "kernel") for ax in axes_kernel_raw) |
| 77 | + |
| 78 | + if len(axes_array) != len(axes_kernel): |
| 79 | + raise ValueError( |
| 80 | + "The number of convolution axes for the array and kernel must be the same, " |
| 81 | + f"got {len(axes_array)} and {len(axes_kernel)}." |
| 82 | + ) |
| 83 | + |
| 84 | + if len(set(axes_array)) != len(axes_array) or len(set(axes_kernel)) != len(axes_kernel): |
| 85 | + raise ValueError("Convolution axes must be unique for both the array and the kernel.") |
| 86 | + |
| 87 | + return axes_array, axes_kernel |
| 88 | + |
| 89 | + |
| 90 | +def _fft_convolve_general( |
| 91 | + array: NDArray, |
| 92 | + kernel: NDArray, |
| 93 | + axes_array: tuple[int, ...], |
| 94 | + axes_kernel: tuple[int, ...], |
| 95 | + mode: Literal["full", "valid"], |
| 96 | +) -> NDArray: |
| 97 | + """Perform convolution using FFT along the specified axes.""" |
| 98 | + |
| 99 | + num_conv_axes = len(axes_array) |
| 100 | + |
| 101 | + if num_conv_axes == 0: |
| 102 | + array_shape = array.shape |
| 103 | + kernel_shape = kernel.shape |
| 104 | + result = np.multiply( |
| 105 | + array.reshape(array_shape + (1,) * kernel.ndim), |
| 106 | + kernel.reshape((1,) * array.ndim + kernel_shape), |
| 107 | + ) |
| 108 | + return result.reshape(array_shape + kernel_shape) |
| 109 | + |
| 110 | + ignore_axes_array = tuple(ax for ax in range(array.ndim) if ax not in axes_array) |
| 111 | + ignore_axes_kernel = tuple(ax for ax in range(kernel.ndim) if ax not in axes_kernel) |
| 112 | + |
| 113 | + new_order_array = ignore_axes_array + axes_array |
| 114 | + new_order_kernel = ignore_axes_kernel + axes_kernel |
| 115 | + |
| 116 | + array_reordered = np.transpose(array, new_order_array) if array.ndim else array |
| 117 | + kernel_reordered = np.transpose(kernel, new_order_kernel) if kernel.ndim else kernel |
| 118 | + |
| 119 | + num_batch_array = len(ignore_axes_array) |
| 120 | + num_batch_kernel = len(ignore_axes_kernel) |
| 121 | + |
| 122 | + array_conv_shape = array_reordered.shape[num_batch_array:] |
| 123 | + kernel_conv_shape = kernel_reordered.shape[num_batch_kernel:] |
| 124 | + |
| 125 | + if any(d <= 0 for d in array_conv_shape + kernel_conv_shape): |
| 126 | + raise ValueError("Convolution dimensions must be positive; got zero-length axis.") |
| 127 | + |
| 128 | + fft_axes = tuple(range(-num_conv_axes, 0)) |
| 129 | + fft_shape = [next_fast_len(n + k - 1) for n, k in zip(array_conv_shape, kernel_conv_shape)] |
| 130 | + use_real_fft = fft_shape[-1] % 2 == 0 # only applicable in this case |
| 131 | + |
| 132 | + fft_fun = rfftn if use_real_fft else fftn |
| 133 | + array_fft = fft_fun(array_reordered, fft_shape, axes=fft_axes) |
| 134 | + kernel_fft = fft_fun(kernel_reordered, fft_shape, axes=fft_axes) |
| 135 | + |
| 136 | + if num_batch_kernel: |
| 137 | + array_batch_shape = array_fft.shape[:num_batch_array] |
| 138 | + conv_shape = array_fft.shape[num_batch_array:] |
| 139 | + array_fft = np.reshape( |
| 140 | + array_fft, |
| 141 | + array_batch_shape + (1,) * num_batch_kernel + conv_shape, |
| 142 | + ) |
| 143 | + |
| 144 | + if num_batch_array: |
| 145 | + kernel_batch_shape = kernel_fft.shape[:num_batch_kernel] |
| 146 | + conv_shape = kernel_fft.shape[num_batch_kernel:] |
| 147 | + kernel_fft = np.reshape( |
| 148 | + kernel_fft, |
| 149 | + (1,) * num_batch_array + kernel_batch_shape + conv_shape, |
| 150 | + ) |
| 151 | + use_real_fft = fft_shape[-1] % 2 == 0 |
| 152 | + |
| 153 | + product = array_fft * kernel_fft |
| 154 | + |
| 155 | + ifft_fun = irfftn if use_real_fft else ifftn |
| 156 | + full_result = ifft_fun(product, fft_shape, axes=fft_axes) |
| 157 | + |
| 158 | + if mode == "full": |
| 159 | + result = full_result |
| 160 | + elif mode == "valid": |
| 161 | + valid_slices = [slice(None)] * full_result.ndim |
| 162 | + for axis_offset, (array_dim, kernel_dim) in enumerate( |
| 163 | + zip(array_conv_shape, kernel_conv_shape) |
| 164 | + ): |
| 165 | + start = int(min(array_dim, kernel_dim) - 1) |
| 166 | + length = int(abs(array_dim - kernel_dim) + 1) |
| 167 | + axis = full_result.ndim - num_conv_axes + axis_offset |
| 168 | + valid_slices[axis] = slice(start, start + length) |
| 169 | + result = full_result[tuple(valid_slices)] |
| 170 | + else: |
| 171 | + raise ValueError(f"Unsupported convolution mode '{mode}'.") |
| 172 | + |
| 173 | + return np.real(result) |
| 174 | + |
| 175 | + |
40 | 176 | def _get_pad_indices( |
41 | 177 | n: int, |
42 | 178 | pad_width: tuple[int, int], |
@@ -148,7 +284,7 @@ def convolve( |
148 | 284 | kernel: NDArray, |
149 | 285 | *, |
150 | 286 | padding: PaddingType = "constant", |
151 | | - axes: Union[tuple[list[int], list[int]], None] = None, |
| 287 | + axes: Union[tuple[list[SupportsInt], list[SupportsInt]], None] = None, |
152 | 288 | mode: Literal["full", "valid", "same"] = "same", |
153 | 289 | ) -> NDArray: |
154 | 290 | """Convolve an array with a given kernel. |
@@ -180,19 +316,23 @@ def convolve( |
180 | 316 | if any(k % 2 == 0 for k in kernel.shape): |
181 | 317 | raise ValueError(f"All kernel dimensions must be odd, got {kernel.shape}.") |
182 | 318 |
|
183 | | - if kernel.ndim != array.ndim and axes is None: |
184 | | - raise ValueError( |
185 | | - f"Kernel dimensions must match array dimensions, got kernel {kernel.shape} and array {array.shape}." |
186 | | - ) |
| 319 | + axes_array, axes_kernel = _normalize_axes(array.ndim, kernel.ndim, axes) |
| 320 | + |
| 321 | + working_array = array |
| 322 | + effective_mode = mode |
187 | 323 |
|
188 | | - if mode in ("same", "full"): |
189 | | - kernel_dims = kernel.shape if axes is None else [kernel.shape[d] for d in axes[1]] |
190 | | - pad_widths = [(ks // 2, ks // 2) for ks in kernel_dims] |
191 | | - for axis, pad_width in enumerate(pad_widths): |
192 | | - array = pad(array, pad_width, mode=padding, axis=axis) |
193 | | - mode = "valid" if mode == "same" else mode |
| 324 | + if mode in ["same", "full"]: |
| 325 | + for ax_array, ax_kernel in zip(axes_array, axes_kernel): |
| 326 | + pad_width = ( |
| 327 | + kernel.shape[ax_kernel] // 2 if mode == "same" else kernel.shape[ax_kernel] - 1 |
| 328 | + ) |
| 329 | + if pad_width > 0: |
| 330 | + working_array = pad( |
| 331 | + working_array, (pad_width, pad_width), mode=padding, axis=ax_array |
| 332 | + ) |
| 333 | + effective_mode = "valid" |
194 | 334 |
|
195 | | - return convolve_ag(array, kernel, axes=axes, mode=mode) |
| 335 | + return _fft_convolve_general(working_array, kernel, axes_array, axes_kernel, effective_mode) |
196 | 336 |
|
197 | 337 |
|
198 | 338 | def _get_footprint(size, structure, maxval): |
|
0 commit comments