Implement Convolve2D Op #1397

Draft · wants to merge 3 commits into main

138 changes: 137 additions & 1 deletion pytensor/tensor/signal/conv.py
@@ -2,16 +2,18 @@

import numpy as np
from numpy import convolve as numpy_convolve
from scipy.signal import convolve2d as scipy_convolve2d

from pytensor.gradient import DisconnectedType
from pytensor.graph import Apply, Constant
from pytensor.graph.op import Op
from pytensor.link.c.op import COp
from pytensor.scalar import as_scalar
from pytensor.scalar.basic import upcast
from pytensor.tensor.basic import as_tensor_variable, join, zeros
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import maximum, minimum, switch
from pytensor.tensor.type import vector
from pytensor.tensor.type import matrix, vector
from pytensor.tensor.variable import TensorVariable


@@ -211,3 +213,137 @@ def convolve1d(

    full_mode = as_scalar(np.bool_(mode == "full"))
    return cast(TensorVariable, blockwise_convolve_1d(in1, in2, full_mode))


class Convolve2D(Op):
    __props__ = ("mode", "boundary", "fillvalue")
    gufunc_signature = "(n,m),(k,l)->(o,p)"

    def __init__(
        self,
        mode: Literal["full", "valid"] = "full",
        boundary: Literal["fill", "wrap", "symm"] = "fill",
        fillvalue: float | int = 0,
    ):
        if mode not in ("full", "valid"):
            raise ValueError(f"Invalid mode: {mode}")

        self.mode = mode
        self.fillvalue = fillvalue
        self.boundary = boundary

    def make_node(self, in1, in2):
        in1, in2 = map(as_tensor_variable, (in1, in2))

        assert in1.ndim == 2
        assert in2.ndim == 2

        dtype = upcast(in1.dtype, in2.dtype)

        n, m = in1.type.shape
        k, l = in2.type.shape

        # Static output shapes: "full" counts every overlap of the two inputs
        # (n + k - 1 per axis), while "valid" keeps only complete overlaps
        # (max - min + 1 per axis). Unknown input dims propagate as None.
        if self.mode == "full":
            shape_1 = None if (n is None or k is None) else n + k - 1
            shape_2 = None if (m is None or l is None) else m + l - 1

        elif self.mode == "valid":
            shape_1 = None if (n is None or k is None) else max(n, k) - min(n, k) + 1
            shape_2 = None if (m is None or l is None) else max(m, l) - min(m, l) + 1

        else:  # mode == "same" (unreachable: __init__ only accepts "full" and "valid")
            shape_1 = n
            shape_2 = m

        out_shape = (shape_1, shape_2)
        out = matrix(dtype=dtype, shape=out_shape)
        return Apply(self, [in1, in2], [out])

    def perform(self, node, inputs, outputs):
        in1, in2 = inputs

        # if all(inpt.dtype.kind in ['f', 'c'] for inpt in inputs):
        #     outputs[0][0] = scipy_convolve(in1, in2, mode=self.mode, method='fft')
        #
        # else:
        outputs[0][0] = scipy_convolve2d(
            in1, in2, mode=self.mode, fillvalue=self.fillvalue, boundary=self.boundary
        )

    def infer_shape(self, fgraph, node, shapes):
        in1_shape, in2_shape = shapes
        n, m = in1_shape
        k, l = in2_shape

        if self.mode == "full":
            shape = (n + k - 1, m + l - 1)
        elif self.mode == "valid":
            shape = (
                maximum(n, k) - minimum(n, k) + 1,
                maximum(m, l) - minimum(m, l) + 1,
            )
        else:  # self.mode == 'same'
            shape = (n, m)

        return [shape]

    def L_op(self, inputs, outputs, output_grads):
        in1, in2 = inputs
        incoming_grads = output_grads[0]

        if self.mode == "full":
            prop_dict = self._props_dict()
            prop_dict["mode"] = "valid"
            conv_valid = type(self)(**prop_dict)

            # The gradient of a full convolution w.r.t. either input is a valid
            # convolution of the output gradient with the *other* input flipped
            # along both axes (i.e. a cross-correlation), as in Convolve1d.L_op.
            in1_grad = conv_valid(incoming_grads, in2[::-1, ::-1])
            in2_grad = conv_valid(incoming_grads, in1[::-1, ::-1])

            return [in1_grad, in2_grad]

        raise NotImplementedError(f"L_op is not implemented for mode={self.mode!r}")


def convolve2d(
    in1: "TensorLike",
    in2: "TensorLike",
    mode: Literal["full", "valid", "same"] = "full",
    boundary: Literal["fill", "wrap", "symm"] = "fill",
    fillvalue: float | int = 0,
) -> TensorVariable:
    """Convolve two two-dimensional arrays.

    Convolve in1 and in2, with the output size determined by the mode argument.

    Parameters
    ----------
    in1 : (..., N, M) tensor_like
        First input.
    in2 : (..., K, L) tensor_like
        Second input.
    mode : {'full', 'valid', 'same'}, optional
        A string indicating the size of the output:
        - 'full': The output is the full discrete linear convolution of the inputs, with shape (..., N+K-1, M+L-1).
        - 'valid': The output consists only of elements that do not rely on zero-padding, with shape (..., max(N, K) - min(N, K) + 1, max(M, L) - min(M, L) + 1).
        - 'same': The output is the same size as in1, centered with respect to the 'full' output.
    boundary : {'fill', 'wrap', 'symm'}, optional
        A string indicating how to handle boundaries:
        - 'fill': Pads the input arrays with fillvalue.
        - 'wrap': Circularly wraps the input arrays.
        - 'symm': Symmetrically reflects the input arrays.
    fillvalue : float or int, optional
        The value to use for padding when boundary is 'fill'. Default is 0.

    Returns
    -------
    out : tensor_variable
        The discrete linear convolution of in1 with in2.

    """
    in1 = as_tensor_variable(in1)
    in2 = as_tensor_variable(in2)

    # TODO: Handle boundaries symbolically
    # TODO: Handle 'same' symbolically

    blockwise_convolve = Blockwise(
        Convolve2D(mode=mode, boundary=boundary, fillvalue=fillvalue)
    )
    return cast(TensorVariable, blockwise_convolve(in1, in2))
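
A minimal usage sketch of the new helper for reviewers trying the branch (the import path follows the test file below; the variable names, shapes, and seed are illustrative assumptions):

    import numpy as np

    from pytensor import function
    from pytensor.tensor import matrix
    from pytensor.tensor.signal.conv import convolve2d

    x = matrix("x")
    k = matrix("k")
    out = convolve2d(x, k, mode="full")  # static shape: (N + K - 1, M + L - 1)

    fn = function([x, k], out)
    rng = np.random.default_rng(0)
    x_val = rng.normal(size=(5, 5)).astype(x.dtype)
    k_val = rng.normal(size=(3, 3)).astype(k.dtype)
    print(fn(x_val, k_val).shape)  # (7, 7)

    # The Op is wrapped in Blockwise, so leading batch dimensions broadcast:
    # a (10, 5, 5) batch against a single (3, 3) kernel should give (10, 7, 7).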
61 changes: 60 additions & 1 deletion tests/tensor/signal/test_conv.py
@@ -3,13 +3,14 @@
import numpy as np
import pytest
from scipy.signal import convolve as scipy_convolve
from scipy.signal import convolve2d as scipy_convolve2d

from pytensor import config, function, grad
from pytensor.graph.basic import ancestors, io_toposort
from pytensor.graph.rewriting import rewrite_graph
from pytensor.tensor import matrix, tensor, vector
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.signal.conv import Convolve1d, convolve1d
from pytensor.tensor.signal.conv import Convolve1d, convolve1d, convolve2d
from tests import unittest_tools as utt


@@ -137,3 +138,61 @@ def convolve1d_grad_benchmarker(convolve_mode, mode, benchmark):
@pytest.mark.parametrize("convolve_mode", ["full", "valid"])
def test_convolve1d_grad_benchmark_c(convolve_mode, benchmark):
    convolve1d_grad_benchmarker(convolve_mode, "FAST_RUN", benchmark)


@pytest.mark.parametrize(
"kernel_shape", [(3, 3), (5, 3), (5, 8)], ids=lambda x: f"kernel_shape={x}"
)
@pytest.mark.parametrize(
"data_shape", [(3, 3), (5, 5), (8, 8)], ids=lambda x: f"data_shape={x}"
)
Comment on lines +143 to +148

ricardoV94 (Member) commented on Jul 14, 2025:
I would like a parametrization where one of the dimensions is larger and the other smaller than the respective dimensions of the other input, something like (5, 5) vs (3, 7) (with both input orders), especially for the grad. This is something that cannot happen in Conv1D and I want to be sure we are doing it correctly.

jessegrabowski (Member, Author):
I think in that case we can swap the inputs, then swap them back?

ricardoV94 (Member):
Not if you only have runtime shapes.

jessegrabowski (Member, Author):
scipy will do this swap internally when mode="valid", see here. This helper is called by both convolve and convolve2d. Our gradients will be wrong if we don't take that into account.

ricardoV94 (Member):
Yeah, I hate that. It should be an implementation detail under the hood and not affect us, though?

ricardoV94 (Member):
I think numpy convolve also does it and it's not a problem for us?

jessegrabowski (Member, Author):
I guess it's "luck", because the L_op calls self.perform, so gradient inputs will also be flipped.

ricardoV94 (Member):
Luck? Wait, aren't we just talking about the same issue that's addressed by #1522 (and before that by doing the worst-case-scenario pad and throwing away the waste)?

jessegrabowski (Member, Author) commented on Jul 14, 2025:
I don't think so. My point is that the shapes of the gradient inputs are the same as the shapes of the inputs, and we call the same function again, so the flipping/unflipping is done for us correctly in both cases. If the gradient for conv1d didn't end up itself being a convolution, we would have had problems.
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
@pytest.mark.parametrize("boundary", ["fill", "wrap", "symm"])
def test_convolve2d(kernel_shape, data_shape, mode, boundary):
    data = matrix("data")
    kernel = matrix("kernel")
    op = partial(convolve2d, mode=mode, boundary=boundary, fillvalue=0)

    rng = np.random.default_rng((26, kernel_shape, data_shape, sum(map(ord, mode))))
    data_val = rng.normal(size=data_shape).astype(data.dtype)
    kernel_val = rng.normal(size=kernel_shape).astype(kernel.dtype)

    fn = function([data, kernel], op(data, kernel))
    np.testing.assert_allclose(
        fn(data_val, kernel_val),
        scipy_convolve2d(
            data_val, kernel_val, mode=mode, boundary=boundary, fillvalue=0
        ),
        atol=1e-6 if config.floatX == "float32" else 1e-8,
    )

    utt.verify_grad(lambda k: op(data_val, k).sum(), [kernel_val], eps=1e-4)
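
The input-swap behavior raised in the review thread can be checked directly against scipy; a standalone sketch (shapes and seed are illustrative, and the mixed-shape case shows why (5, 5) vs (3, 7) needs care for mode="valid"):

    import numpy as np
    from scipy.signal import convolve2d as scipy_convolve2d

    rng = np.random.default_rng(0)
    big = rng.normal(size=(5, 5))
    small = rng.normal(size=(3, 3))

    # scipy swaps the inputs internally for mode="valid", so argument order
    # does not change the result when one input contains the other:
    a = scipy_convolve2d(big, small, mode="valid")
    b = scipy_convolve2d(small, big, mode="valid")
    print(np.allclose(a, b))  # True

    # Mixed shapes, each larger along one axis, are rejected by scipy outright:
    try:
        scipy_convolve2d(big, rng.normal(size=(3, 7)), mode="valid")
    except ValueError as err:
        print(err)  # one input must be at least as large in every dimension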


# @pytest.mark.parametrize(
# "data_shape, kernel_shape", [[(10, 1, 8, 8), (3, 1, 3, 3)], # 8x8 grayscale
# [(1000, 1, 8, 8), (3, 1, 1, 3)], # same, but with 1000 images
# [(10, 3, 64, 64), (10, 3, 8, 8)], # 64x64 RGB
# [(1000, 3, 64, 64), (10, 3, 8, 8)], # same, but with 1000 images
# [(3, 100, 100, 100), (250, 100, 50, 50)]], # Very large, deep hidden layer or something
#
# ids=lambda x: f"data_shape={x[0]}, kernel_shape={x[1]}"
# )
# @pytest.mark.parametrize('func', ['new', 'theano'], ids=['new-impl', 'theano-impl'])
# def test_conv2d_nn_benchmark(data_shape, kernel_shape, func, benchmark):
# import pytensor.tensor as pt
# x = pt.tensor("x", shape=data_shape)
# y = pt.tensor("y", shape=kernel_shape)
#
# if func == 'new':
# out = nn_conv2d(x, y)
# else:
# out = conv2d(input=x, filters=y, border_mode="valid")
#
# rng = np.random.default_rng(38)
# x_test = rng.normal(size=data_shape).astype(x.dtype)
# y_test = rng.normal(size=kernel_shape).astype(y.dtype)
#
# fn = function([x, y], out, trust_input=True)
#
# benchmark(fn, x_test, y_test)