7 | 7 | import numpy as onp |
8 | 8 | from autograd import jacobian |
9 | 9 | from autograd.extend import defvjp, primitive |
10 | | -from autograd.scipy.signal import convolve as convolve_ag |
| 10 | +from autograd.numpy.fft import fftn, ifftn |
11 | 11 | from autograd.scipy.special import logsumexp |
12 | 12 | from autograd.tracer import getval |
13 | 13 | from numpy.lib.stride_tricks import sliding_window_view |
37 | 37 | ] |
38 | 38 | |
39 | 39 | |
| 40 | +def _normalize_axes( |
| 41 | + ndim_array: int, |
| 42 | + ndim_kernel: int, |
| 43 | + axes: Union[tuple[Iterable[int], Iterable[int]], None], |
| 44 | +) -> tuple[tuple[int, ...], tuple[int, ...]]: |
| 45 | + """Normalize the axes specification for convolution.""" |
| 46 | + |
| 47 | + def _normalize_single_axis(ax: int, ndim: int, kind: str) -> int: |
| 48 | + if not isinstance(ax, int): |
| 49 | + raise TypeError( |
| 50 | + f"Axis indices must be integers, got {ax!r} of type {type(ax).__name__}." |
| 51 | + ) |
| 52 | + |
| 53 | + if not -ndim <= ax < ndim: |
| 54 | + raise ValueError(f"Invalid axis {ax} for {kind} with ndim {ndim}.") |
| 55 | + return ax + ndim if ax < 0 else ax |
| 56 | + |
| 57 | + if axes is None: |
| 58 | + if ndim_array != ndim_kernel: |
| 59 | + raise ValueError( |
| 60 | + "Kernel dimensions must match array dimensions when 'axes' is not provided, " |
| 61 | + f"got array ndim {ndim_array} and kernel ndim {ndim_kernel}." |
| 62 | + ) |
| 63 | + axes_array = tuple(range(ndim_array)) |
| 64 | + axes_kernel = tuple(range(ndim_kernel)) |
| 65 | + return axes_array, axes_kernel |
| 66 | + |
| 67 | + if len(axes) != 2: |
| 68 | + raise ValueError("'axes' must be a tuple of two iterable collections of axis indices.") |
| 69 | + |
| 70 | + axes_array_raw, axes_kernel_raw = axes |
| 71 | + |
| 72 | + axes_array = tuple(_normalize_single_axis(ax, ndim_array, "array") for ax in axes_array_raw) |
| 73 | + axes_kernel = tuple(_normalize_single_axis(ax, ndim_kernel, "kernel") for ax in axes_kernel_raw) |
| 74 | + |
| 75 | + if len(axes_array) != len(axes_kernel): |
| 76 | + raise ValueError( |
| 77 | + "The number of convolution axes for the array and kernel must be the same, " |
| 78 | + f"got {len(axes_array)} and {len(axes_kernel)}." |
| 79 | + ) |
| 80 | + |
| 81 | + if len(set(axes_array)) != len(axes_array) or len(set(axes_kernel)) != len(axes_kernel): |
| 82 | + raise ValueError("Convolution axes must be unique for both the array and the kernel.") |
| 83 | + |
| 84 | + return axes_array, axes_kernel |
| 85 | + |
| 86 | + |
| 87 | +def _fft_convolve_general( |
| 88 | + array: NDArray, |
| 89 | + kernel: NDArray, |
| 90 | + axes_array: tuple[int, ...], |
| 91 | + axes_kernel: tuple[int, ...], |
| 92 | + mode: Literal["full", "valid"], |
| 93 | +) -> NDArray: |
| 94 | + """Perform convolution using FFT along the specified axes.""" |
| 95 | + |
| 96 | + num_conv_axes = len(axes_array) |
| 97 | + |
| 98 | + if num_conv_axes == 0: |
| 99 | + array_shape = array.shape |
| 100 | + kernel_shape = kernel.shape |
| 101 | + result = np.multiply( |
| 102 | + array.reshape(array_shape + (1,) * kernel.ndim), |
| 103 | + kernel.reshape((1,) * array.ndim + kernel_shape), |
| 104 | + ) |
| 105 | + return result.reshape(array_shape + kernel_shape) |
| 106 | + |
| 107 | + ignore_axes_array = tuple(ax for ax in range(array.ndim) if ax not in axes_array) |
| 108 | + ignore_axes_kernel = tuple(ax for ax in range(kernel.ndim) if ax not in axes_kernel) |
| 109 | + |
| 110 | + new_order_array = ignore_axes_array + axes_array |
| 111 | + new_order_kernel = ignore_axes_kernel + axes_kernel |
| 112 | + |
| 113 | + array_reordered = np.transpose(array, new_order_array) if array.ndim else array |
| 114 | + kernel_reordered = np.transpose(kernel, new_order_kernel) if kernel.ndim else kernel |
| 115 | + |
| 116 | + num_batch_array = len(ignore_axes_array) |
| 117 | + num_batch_kernel = len(ignore_axes_kernel) |
| 118 | + |
| 119 | + array_conv_shape = array_reordered.shape[num_batch_array:] |
| 120 | + kernel_conv_shape = kernel_reordered.shape[num_batch_kernel:] |
| 121 | + |
| 122 | + if any(d <= 0 for d in array_conv_shape + kernel_conv_shape): |
| 123 | + raise ValueError("Convolution dimensions must be positive; got zero-length axis.") |
| 124 | + |
| 125 | + fft_axes = tuple(range(-num_conv_axes, 0)) |
| 126 | + fft_shape = tuple( |
| 127 | + int(array_dim + kernel_dim - 1) |
| 128 | + for array_dim, kernel_dim in zip(array_conv_shape, kernel_conv_shape) |
| 129 | + ) |
| 130 | + |
| 131 | + array_fft = fftn(array_reordered, fft_shape, axes=fft_axes) |
| 132 | + kernel_fft = fftn(kernel_reordered, fft_shape, axes=fft_axes) |
| 133 | + |
| 134 | + if num_batch_kernel: |
| 135 | + array_batch_shape = array_fft.shape[:num_batch_array] |
| 136 | + conv_shape = array_fft.shape[num_batch_array:] |
| 137 | + array_fft = np.reshape( |
| 138 | + array_fft, |
| 139 | + array_batch_shape + (1,) * num_batch_kernel + conv_shape, |
| 140 | + ) |
| 141 | + |
| 142 | + if num_batch_array: |
| 143 | + kernel_batch_shape = kernel_fft.shape[:num_batch_kernel] |
| 144 | + conv_shape = kernel_fft.shape[num_batch_kernel:] |
| 145 | + kernel_fft = np.reshape( |
| 146 | + kernel_fft, |
| 147 | + (1,) * num_batch_array + kernel_batch_shape + conv_shape, |
| 148 | + ) |
| 149 | + |
| 150 | + full_result = ifftn(array_fft * kernel_fft, fft_shape, axes=fft_axes) |
| 151 | + |
| 152 | + if mode == "full": |
| 153 | + result = full_result |
| 154 | + elif mode == "valid": |
| 155 | + valid_slices = [slice(None)] * full_result.ndim |
| 156 | + for axis_offset, (array_dim, kernel_dim) in enumerate( |
| 157 | + zip(array_conv_shape, kernel_conv_shape) |
| 158 | + ): |
| 159 | + start = int(min(array_dim, kernel_dim) - 1) |
| 160 | + length = int(abs(array_dim - kernel_dim) + 1) |
| 161 | + axis = full_result.ndim - num_conv_axes + axis_offset |
| 162 | + valid_slices[axis] = slice(start, start + length) |
| 163 | + result = full_result[tuple(valid_slices)] |
| 164 | + else: |
| 165 | + raise ValueError(f"Unsupported convolution mode '{mode}'.") |
| 166 | + |
| 167 | + if not np.iscomplexobj(array) and not np.iscomplexobj(kernel): |
| 168 | + result = np.real(result) |
| 169 | + |
| 170 | + return result |
| 171 | + |
| 172 | + |
40 | 173 | def _get_pad_indices( |
41 | 174 | n: int, |
42 | 175 | pad_width: tuple[int, int], |
@@ -180,19 +313,27 @@ def convolve( |
180 | 313 | if any(k % 2 == 0 for k in kernel.shape): |
181 | 314 | raise ValueError(f"All kernel dimensions must be odd, got {kernel.shape}.") |
182 | 315 | |
183 | | - if kernel.ndim != array.ndim and axes is None: |
184 | | - raise ValueError( |
185 | | - f"Kernel dimensions must match array dimensions, got kernel {kernel.shape} and array {array.shape}." |
186 | | - ) |
187 | | - |
188 | | - if mode in ("same", "full"): |
189 | | - kernel_dims = kernel.shape if axes is None else [kernel.shape[d] for d in axes[1]] |
190 | | - pad_widths = [(ks // 2, ks // 2) for ks in kernel_dims] |
191 | | - for axis, pad_width in enumerate(pad_widths): |
192 | | - array = pad(array, pad_width, mode=padding, axis=axis) |
193 | | - mode = "valid" if mode == "same" else mode |
194 | | - |
195 | | - return convolve_ag(array, kernel, axes=axes, mode=mode) |
| 316 | + axes_array, axes_kernel = _normalize_axes(array.ndim, kernel.ndim, axes) |
| 317 | + |
| 318 | + working_array = array |
| 319 | + effective_mode = mode |
| 320 | + |
| 321 | + if mode in ["same", "full"]: |
| 322 | + for ax_array, ax_kernel in zip(axes_array, axes_kernel): |
| 323 | + if mode == "same" and kernel.shape[ax_kernel] % 2 == 0: |
| 324 | + raise ValueError( |
| 325 | + f"Even-sized kernel along axis {ax_kernel} not supported for 'same' mode." |
| 326 | + ) |
| 327 | + pad_width = ( |
| 328 | + kernel.shape[ax_kernel] // 2 if mode == "same" else kernel.shape[ax_kernel] - 1 |
| 329 | + ) |
| 330 | + if pad_width > 0: |
| 331 | + working_array = pad( |
| 332 | + working_array, (pad_width, pad_width), mode=padding, axis=ax_array |
| 333 | + ) |
| 334 | + effective_mode = "valid" |
| 335 | + |
| 336 | + return _fft_convolve_general(working_array, kernel, axes_array, axes_kernel, effective_mode) |
196 | 337 | |
197 | 338 | |
198 | 339 | def _get_footprint(size, structure, maxval): |
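
The zero-padding to `array_dim + kernel_dim - 1` and the "valid"-mode slicing in `_fft_convolve_general` follow the standard FFT convolution identity. As a quick sanity check, here is a minimal self-contained 1-D sketch in plain NumPy (illustrative only, not part of this change; the names `a`, `k`, `n_full` are arbitrary) reproducing the same start/length arithmetic against `np.convolve`:

```python
# Sanity-check sketch of the FFT convolution arithmetic used in
# _fft_convolve_general, written against plain NumPy (not autograd).
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal(10)   # "array" signal
k = rng.standard_normal(3)    # "kernel" signal

# "full" mode: zero-pad both inputs to len(a) + len(k) - 1 and multiply spectra.
n_full = a.size + k.size - 1
full = np.real(np.fft.ifft(np.fft.fft(a, n_full) * np.fft.fft(k, n_full)))
assert np.allclose(full, np.convolve(a, k, mode="full"))

# "valid" mode: a slice of the "full" result, using the same start/length
# computation as the mode == "valid" branch above.
start = min(a.size, k.size) - 1
length = abs(a.size - k.size) + 1
assert np.allclose(full[start:start + length], np.convolve(a, k, mode="valid"))
```

The same identity extends axis-by-axis to the N-D case handled in the diff, with the non-convolution axes of the array and kernel broadcast against each other as batch dimensions.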