Make convolve mode symbolic to avoid unnecessary large convolution in gradient graph #1522
Conversation
Codecov Report
Attention: Patch coverage is
Additional details and impacted files:

```
@@           Coverage Diff            @@
##             main    #1522    +/-  ##
========================================
- Coverage   82.04%   82.04%   -0.01%
========================================
  Files         231      230       -1
  Lines       52364    52346      -18
  Branches     9217     9212       -5
========================================
- Hits        42962    42947      -15
- Misses       7094     7095       +1
+ Partials     2308     2304       -4
```
Approved with some nitpick comments
```python
in2_bar = in2_bar[nmk : in2_bar.shape[0] - nmk]

return [in1_bar, in2_bar]

# If mode is "full", or mode is "valid" and k >= n, then in1_bar mode should use "valid" convolve
```
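The mode bookkeeping in this hunk follows from how convolution adjoints work. As a hedged illustration in plain NumPy (not code from this PR): the gradient of a "full" convolution with respect to the long input is a "valid" correlation, so the backward graph never needs an output longer than the inputs involved.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 6, 3
x, w = rng.normal(size=n), rng.normal(size=k)
g = rng.normal(size=n + k - 1)  # upstream gradient of the "full" output

# Full convolution: y[t] = sum_i x[i] * w[t - i]
# Therefore dL/dx[i] = sum_j g[i + j] * w[j], computed here from the definition:
manual = np.array([sum(g[i + j] * w[j] for j in range(k)) for i in range(n)])

# ...which is exactly a "valid" correlation of g with w, a *small* convolution:
valid_grad = np.correlate(g, w, mode="valid")
assert np.allclose(manual, valid_grad)
```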
This is really slick!
```python
out[0] = s.flatten()[0]

def perform(self, node, inputs, output_storage):
    # not using .item() because that returns a Python scalar, not a numpy scalar
    output_storage[0][0] = inputs[0][()]
```
That `[()]` syntax is really ugly.
It's how you index/update a 0d array in numpy
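For readers following along, a small NumPy illustration of that indexing (not code from this PR): the empty tuple `()` is the only valid "basic" index into a 0-d array, and unlike `.item()` it keeps the NumPy dtype.

```python
import numpy as np

a = np.array(3.5)   # 0-d array: shape (), no axes to index into
s = a[()]           # reads out a numpy scalar (np.float64), dtype preserved
p = a.item()        # reads out a plain Python float instead
a[()] = 7.0         # in-place update of the 0-d array
assert type(p) is float and isinstance(s, np.floating)
assert a == 7.0
```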
You can also use `None`, I believe.
Edit: it only works for assignment. If this is really what they want you to do in this case, that sucks. But c'est la vie.
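A quick NumPy check of that claim (illustrative, not from the PR): `None` is `np.newaxis`, so it produces a shape-`(1,)` view you can assign through, but reading it back gives a 1-d array rather than a scalar.

```python
import numpy as np

a = np.array(1.0)
a[None] = 2.0   # assignment works: None (np.newaxis) yields a shape-(1,) view
assert a == 2.0
assert a[None].shape == (1,)  # but reading a[None] gives a 1-d array, not a scalar
```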
The idea is that you're indexing an array without indices/dimensions, so you can't say "pick the first entry of the first dimension" (there are no dims). Anyway, I'm surprised it bothers you so much.
Instead of statically parametrizing the convolution type at the Op level, it now uses a scalar boolean that can be set symbolically. This was fine except for JAX, where we can't branch symbolically like that.
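As a rough sketch of the idea (plain Python/NumPy, not PyTensor's actual Op): the mode becomes a runtime input rather than a construction-time attribute. Under JAX tracing, the Python `if` on a traced boolean is exactly what fails, which is why the mode has to be constant there.

```python
import numpy as np

def convolve_perform(in1, in2, full_mode):
    # full_mode is a 0-d boolean input, so it can be set symbolically upstream;
    # the branch is evaluated at runtime, which JAX tracing cannot do.
    mode = "full" if full_mode else "valid"
    return np.convolve(in1, in2, mode=mode)

full = convolve_perform(np.ones(5), np.ones(3), np.bool_(True))
valid = convolve_perform(np.ones(5), np.ones(3), np.bool_(False))
assert full.shape == (7,)   # n + k - 1
assert valid.shape == (3,)  # n - k + 1
```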
As long as the mode is constant (which is the case when the original convolve was full, or the core shapes are statically known) this should be fine, except for one hiccup: in the dispatch of Blockwise we couldn't see the outer inputs.
I added some functionality when we create the dummy core node to propagate inputs if they don't have batch dimensions. Such inputs won't change over iterations, so making compile or rewrite decisions based on them should be safe. This could also be used for infer_shape, for instance, which could help with lowering some Ops to Numba.
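An illustrative sketch of that invariant (not PyTensor internals): in a Blockwise-style loop, an input with no batch dimensions is the same value on every iteration, so a decision based on it (here the mode flag) can safely be made once, up front.

```python
import numpy as np

def blockwise_convolve(batch_in1, in2, full_mode):
    # in2 and full_mode have no batch dimensions: they are identical across
    # iterations, so the mode decision is hoisted out of the loop.
    mode = "full" if full_mode else "valid"
    return np.stack([np.convolve(row, in2, mode=mode) for row in batch_in1])

out = blockwise_convolve(np.ones((4, 5)), np.ones(3), False)
assert out.shape == (4, 3)  # "valid": 5 - 3 + 1 = 3 per batch row
```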
These changes would also allow us to compile a constant convolve mode in C/Numba, but benchmarks didn't show any gains, so I didn't bother. In any case, future Blockwise implementations for certain Ops can make use of that information.
Relevant benchmark tests:
Note the worst-case scenario, where we were doing a full convolution for the smaller input in the gradient of a valid convolution.
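A rough NumPy illustration of the waste that worst case incurred (my numbers, not the PR's benchmark): for a "valid" convolution with n much larger than k, the gradient with respect to the small input only needs k values, but a "full" convolution in the gradient graph computes 2n - k of them before slicing.

```python
import numpy as np

n, k = 1000, 5
g = np.ones(n - k + 1)   # upstream gradient of the "valid" output
in1 = np.ones(n)         # the large input

full_size = len(np.convolve(g, in1, mode="full"))
assert full_size == 2 * n - k         # (n - k + 1) + n - 1 entries computed
assert full_size - k == 2 * (n - k)   # entries discarded by the slice
```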
📚 Documentation preview 📚: https://pytensor--1522.org.readthedocs.build/en/1522/