Make convolve mode symbolic to avoid unnecessary large convolution in gradient graph #1522


Merged · 4 commits into pymc-devs:main on Jul 8, 2025

Conversation

@ricardoV94 (Member) commented on Jul 4, 2025

Instead of statically parametrizing the convolution mode at the Op level, it is now a scalar boolean input that can be set symbolically. This works fine everywhere except JAX, where we can't branch symbolically like that.
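For illustration, here is a minimal sketch of the JAX limitation (my own example, not code from this PR; the exact error class can vary by JAX version):

```python
import jax
import jax.numpy as jnp

def conv(x, y, full_mode):
    # A Python `if` needs a concrete boolean; under jit, `full_mode`
    # becomes a tracer and the branch can't be resolved at trace time.
    if full_mode:
        return jnp.convolve(x, y, mode="full")
    return jnp.convolve(x, y, mode="valid")

x, y = jnp.arange(5.0), jnp.ones(3)
try:
    jax.jit(conv)(x, y, True)  # raises: can't branch on a traced boolean
except jax.errors.ConcretizationTypeError:
    pass  # TracerBoolConversionError (a subclass) in recent JAX versions

# Works when the mode is a static (trace-time) argument instead:
print(jax.jit(conv, static_argnums=2)(x, y, True))
```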

As long as the mode is constant (which is the case when the original convolve was full, or when the core shapes are statically known), this should be fine, except for one hiccup: in the dispatch of Blockwise we couldn't see the outer inputs.

I added some functionality so that, when we create the dummy core node, inputs are propagated if they don't have batch dimensions. Such inputs won't change over iterations, so making compile or rewrite decisions based on them should be safe. This can also be used for infer_shape, for instance, which could help with lowering some Ops to Numba.
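A hypothetical sketch of that idea (the names here are illustrative, not PyTensor's actual API):

```python
# Hypothetical sketch, not PyTensor's actual API: when building the dummy
# core node for a Blockwise Op, outer inputs without batch dimensions are
# passed through as-is, so backend dispatch (or infer_shape) can see
# constant values such as the convolve mode.

def core_inputs_for_dispatch(outer_inputs, batch_ndims):
    """Decide which outer inputs the dummy core node may see.

    outer_inputs: the Blockwise node's inputs
    batch_ndims[i]: number of leading batch dimensions of input i
    """
    visible = []
    for inp, bdims in zip(outer_inputs, batch_ndims):
        if bdims == 0:
            # No batch dimensions: the same value is used on every
            # iteration of the blockwise loop, so compile/rewrite
            # decisions (e.g. a constant convolve mode) are safe.
            visible.append(inp)
        else:
            visible.append(None)  # varies across the batch: unknown
    return visible
```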

These changes would also allow us to compile a constant convolve mode in C/Numba, but benchmarks didn't show any gains, so I didn't bother doing that. In any case, future implementations of Blockwise for specific Ops can make use of this information.

Relevant benchmark tests:

Before:
--------------------------------------------------------------------------- benchmark: 8 tests ---------------------------------------------------------------------------
Name (time in us)                                                   Min                    Max                   Mean              StdDev                 Median
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_convolve1d_benchmark_numba[batch=False-mode=full]           6.3820 (1.06)        117.5200 (1.72)          6.7371 (1.06)       1.0520 (1.27)          6.5720 (1.05)
test_convolve1d_benchmark_numba[batch=False-mode=valid]          6.0410 (1.0)          81.1920 (1.19)          6.3736 (1.0)        0.8259 (1.0)           6.2410 (1.0)
test_convolve1d_benchmark_numba[batch=True-mode=full]           14.0660 (2.33)      1,015.9230 (14.87)        14.6380 (2.30)       5.4193 (6.56)         14.3370 (2.30)
test_convolve1d_benchmark_numba[batch=True-mode=valid]          12.8040 (2.12)         68.3380 (1.0)          13.4789 (2.11)       1.4135 (1.71)         13.0740 (2.09)
test_convolve1d_grad_benchmark_numba[full]                     176.2610 (29.18)       190.4460 (2.79)        181.4098 (28.46)      6.4607 (7.82)        177.5830 (28.45)
test_convolve1d_grad_benchmark_numba[valid]                 10,654.1110 (>1000.0)  10,663.5190 (156.04)   10,658.8200 (>1000.0)    3.5995 (4.36)     10,658.2390 (>1000.0)
test_convolve1d_grad_benchmark_c[full]                          99.6770 (16.50)       188.2430 (2.75)        107.0415 (16.79)      8.2845 (10.03)       103.7550 (16.62)
test_convolve1d_grad_benchmark_c[valid]                      1,945.9860 (322.13)    3,263.2340 (47.75)     2,070.7502 (324.90)   165.5805 (200.49)    1,991.7020 (319.13)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------

After:
--------------------------------------------------------------------- benchmark: 8 tests --------------------------------------------------------------------
Name (time in us)                                                Min                 Max                Mean             StdDev              Median
-------------------------------------------------------------------------------------------------------------------------------------------------------------
test_convolve1d_benchmark_numba[batch=False-mode=full]        6.4320 (1.05)     121.2880 (1.92)       7.0089 (1.08)      1.4235 (1.68)       6.6420 (1.05)
test_convolve1d_benchmark_numba[batch=False-mode=valid]       6.1220 (1.0)       63.2990 (1.0)        6.4999 (1.0)       0.8496 (1.0)        6.3120 (1.0)
test_convolve1d_benchmark_numba[batch=True-mode=full]        14.5480 (2.38)     133.0200 (2.10)      16.1110 (2.48)      2.5351 (2.98)      14.8580 (2.35)
test_convolve1d_benchmark_numba[batch=True-mode=valid]       13.4050 (2.19)      82.8950 (1.31)      14.2643 (2.19)      1.8216 (2.14)      13.7060 (2.17)
test_convolve1d_grad_benchmark_numba[full]                  177.7730 (29.04)    197.0690 (3.11)     183.0290 (28.16)     8.1456 (9.59)     179.0460 (28.37)
test_convolve1d_grad_benchmark_numba[valid]                 175.9900 (28.75)    184.9760 (2.92)     180.1314 (27.71)     4.4441 (5.23)     177.9230 (28.19)
test_convolve1d_grad_benchmark_c[full]                      107.2810 (17.52)    781.7850 (12.35)    122.0496 (18.78)    24.0425 (28.30)    112.4210 (17.81)
test_convolve1d_grad_benchmark_c[valid]                     115.8170 (18.92)    401.0020 (6.34)     127.0086 (19.54)    19.5353 (22.99)    120.4860 (19.09)

Note the worst-case scenario: the gradient of a valid convolution used to do a full convolution for the smaller input, which is why the Numba valid-gradient benchmark drops from ~10.7 ms to ~176 µs.
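For context, a minimal NumPy sketch of the gradients involved (an illustration of the math, not the PR's implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 100, 5
a, b = rng.normal(size=n), rng.normal(size=k)
z_bar = rng.normal(size=n - k + 1)  # cotangent of z = convolve(a, b, "valid")

# Gradient w.r.t. the larger input a: a "full" convolution of the
# output cotangent with the flipped kernel.
a_bar = np.convolve(z_bar, b[::-1], mode="full")

# Gradient w.r.t. the smaller input b: a "valid" convolution suffices.
# Doing a "full" convolution here (and slicing afterwards) is the
# 10.7 ms -> 176 us case in the benchmarks above.
b_bar = np.convolve(a[::-1], z_bar, mode="valid")

# Finite-difference check of one entry of a_bar:
eps, p = 1e-6, 3
a2 = a.copy()
a2[p] += eps
num = (np.convolve(a2, b, "valid") - np.convolve(a, b, "valid")) @ z_bar / eps
assert np.isclose(num, a_bar[p])
```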


📚 Documentation preview 📚: https://pytensor--1522.org.readthedocs.build/en/1522/

@ricardoV94 ricardoV94 force-pushed the make_convolve_mode_symbolic branch 3 times, most recently from 48a5ce3 to 95b4cb3 Compare July 6, 2025 19:19
@ricardoV94 ricardoV94 force-pushed the make_convolve_mode_symbolic branch 3 times, most recently from 281bbf9 to fe2ea6b Compare July 7, 2025 11:21
@ricardoV94 ricardoV94 force-pushed the make_convolve_mode_symbolic branch 2 times, most recently from 12e3123 to 5be6968 Compare July 7, 2025 11:54

codecov bot commented Jul 7, 2025

Codecov Report

Attention: Patch coverage is 68.42105% with 48 lines in your changes missing coverage. Please review.

Project coverage is 82.04%. Comparing base (7584614) to head (9026dd8).
Report is 9 commits behind head on main.

Files with missing lines                      Patch %   Lines
pytensor/link/numba/dispatch/signal/conv.py   16.66%    35 Missing ⚠️
pytensor/tensor/signal/conv.py                80.48%    4 Missing and 4 partials ⚠️
pytensor/tensor/blockwise.py                  93.75%    0 Missing and 3 partials ⚠️
pytensor/link/jax/dispatch/signal/conv.py     80.00%    2 Missing ⚠️
Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1522      +/-   ##
==========================================
- Coverage   82.04%   82.04%   -0.01%     
==========================================
  Files         231      230       -1     
  Lines       52364    52346      -18     
  Branches     9217     9212       -5     
==========================================
- Hits        42962    42947      -15     
- Misses       7094     7095       +1     
+ Partials     2308     2304       -4     
Files with missing lines                      Coverage Δ
pytensor/link/jax/dispatch/blockwise.py       100.00% <100.00%> (ø)
pytensor/link/numba/dispatch/blockwise.py     90.00% <100.00%> (ø)
pytensor/tensor/basic.py                      91.69% <100.00%> (-0.02%) ⬇️
pytensor/tensor/rewriting/blockwise.py        96.55% <100.00%> (+0.47%) ⬆️
pytensor/tensor/rewriting/subtensor_lift.py   92.28% <100.00%> (+0.51%) ⬆️
pytensor/link/jax/dispatch/signal/conv.py     87.50% <80.00%> (-12.50%) ⬇️
pytensor/tensor/blockwise.py                  89.31% <93.75%> (+0.04%) ⬆️
pytensor/tensor/signal/conv.py                87.50% <80.48%> (-7.63%) ⬇️
pytensor/link/numba/dispatch/signal/conv.py   32.69% <16.66%> (+0.69%) ⬆️

... and 1 file with indirect coverage changes


@ricardoV94 ricardoV94 marked this pull request as ready for review July 7, 2025 18:19
@ricardoV94 ricardoV94 requested a review from jessegrabowski July 7, 2025 18:19
@jessegrabowski (Member) left a review:

Approved with some nitpick comments

On this gradient code:

    in2_bar = in2_bar[nmk : in2_bar.shape[0] - nmk]

    return [in1_bar, in2_bar]
    # If mode is "full", or mode is "valid" and k >= n, then in1_bar mode should use "valid" convolve

jessegrabowski: This is really slick!

On this snippet:

    out[0] = s.flatten()[0]

    def perform(self, node, inputs, output_storage):
        # not using .item() because that returns a Python scalar, not a numpy scalar
        output_storage[0][0] = inputs[0][()]

jessegrabowski: That [()] syntax is really ugly

ricardoV94 (Author): It's how you index/update a 0d array in numpy

jessegrabowski (Jul 8, 2025): You can also use None I believe.

Edit: It only works for assignment. If this is really what they want you to do in this case, that sucks. But c'est la vie.

ricardoV94 (Jul 8, 2025): The idea is that you're indexing an array without indices/dimensions, so you can't say "pick the first entry of the first dimension" (there are no dims). Anyway, surprised it bothers you so much.
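For reference, a minimal NumPy illustration of the 0-d indexing behavior discussed above:

```python
import numpy as np

a = np.array(3.5)   # 0-d array: shape (), ndim 0
a[()]               # np.float64(3.5) -- a NumPy scalar, dtype preserved
a.item()            # 3.5 -- a plain Python float, dtype lost
a[()] = 4.0         # the empty tuple also works for assignment
a[None] = 5.0       # assignment through the added axis also works...
a[None]             # ...but indexing with None adds an axis: array([5.])
```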

@ricardoV94 ricardoV94 force-pushed the make_convolve_mode_symbolic branch from 5be6968 to 9026dd8 Compare July 8, 2025 07:54
@ricardoV94 ricardoV94 merged commit 1d82fb4 into pymc-devs:main Jul 8, 2025
72 of 73 checks passed