
Change Dot Op to only accept matrix inputs #1538


Open. ricardoV94 wants to merge 4 commits into main from dot_is_2d.

Conversation

ricardoV94 (Member) commented on Jul 13, 2025:

This builds on top of #1471, removing extra complexity from our representation of matmul. All dots now operate on 2d inputs, and the mat-vec, vec-mat, and vec-vec cases can be detected by introspecting the broadcastable pattern of the inputs. That information should never be lost, and not having to worry about variants where it doesn't matter makes our lives easier.

This PR also removes scipy_ger and uses scipy directly in the perform method of Ger. The separate Op was an artifact of the old Theano days, when scipy was an optional dependency.

With these changes, the whole concept of Dot22 also loses its meaning. We can remove it next and port the C implementation to Dot directly.
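
As a rough illustration of the new design (a sketch only, not the actual implementation; dot_sketch is a hypothetical helper), the user-facing dot can promote vector arguments to matrices, call the matrix-only core Op, and drop the added dimensions afterwards, while the broadcastable pattern of the promoted inputs still records which case it was:

import pytensor.tensor as pt

def dot_sketch(a, b):
    # Hypothetical wrapper: promote vectors to matrices, apply the
    # matrix-only dot, then drop the dimensions that were added.
    a_is_vec = a.type.ndim == 1
    b_is_vec = b.type.ndim == 1
    if a_is_vec:
        a = a[None, :]  # (n,) -> (1, n): a row "matrix" with a broadcastable first dim
    if b_is_vec:
        b = b[:, None]  # (m,) -> (m, 1): a column "matrix" with a broadcastable second dim
    out = pt.dot(a, b)  # matrix-matrix dot on 2d inputs
    # The added dims are length-1/broadcastable, so mat-vec, vec-mat and
    # vec-vec remain distinguishable by introspecting the input types.
    if a_is_vec and b_is_vec:
        return out[0, 0]
    if a_is_vec:
        return out[0]
    if b_is_vec:
        return out[:, 0]
    return out

x = pt.vector("x")
M = pt.matrix("M")
print(dot_sketch(x, M).type.ndim)  # 1, matching np.dot on shapes (n,) and (n, m)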


📚 Documentation preview 📚: https://pytensor--1538.org.readthedocs.build/en/1538/

ricardoV94 force-pushed the dot_is_2d branch 2 times, most recently from 2d5ed60 to bb45939, on July 22, 2025 10:42
ricardoV94 changed the title from "Canonicalize Dot as a matrix-matrix operation" to "Canonicalize dot as a matrix-matrix operation" on Jul 22, 2025

@pytest.mark.parametrize("inplace", (True, False), ids=["inplace", "no_inplace"])
@pytest.mark.parametrize("n", [2**7, 2**9, 2**13])
def test_ger_benchmark(n, inplace, benchmark):
ricardoV94 (Member, Author) commented:

I added this because at some point I considered replacing GER with outer multiplication. The benchmark showed that approach underperformed.
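
For context, a standalone comparison of the two approaches could look like the following (an illustrative sketch, not the benchmark from this PR; it uses scipy's BLAS ger routine and a Fortran-ordered array so the rank-1 update can happen in place):

import timeit

import numpy as np
from scipy.linalg.blas import get_blas_funcs

n = 2**9
rng = np.random.default_rng(0)
x, y = rng.standard_normal(n), rng.standard_normal(n)
A = np.asfortranarray(rng.standard_normal((n, n)))  # F-order lets GER update in place
ger, = get_blas_funcs(("ger",), (A, x, y))

# Rank-1 update A <- A + alpha * outer(x, y): BLAS GER vs. a generic outer product
t_ger = timeit.timeit(lambda: ger(1.0, x, y, a=A, overwrite_a=True), number=100)
t_outer = timeit.timeit(lambda: A + np.outer(x, y), number=100)
print(f"ger: {t_ger:.4f}s  outer+add: {t_outer:.4f}s")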

ricardoV94 force-pushed the dot_is_2d branch 3 times, most recently from 115c865 to 199396e, on July 24, 2025 10:16
ricardoV94 changed the title from "Canonicalize dot as a matrix-matrix operation" to "Change Dot Op to only accept matrix inputs" on Jul 24, 2025
ricardoV94 marked this pull request as ready for review on July 24, 2025 11:54
Copilot AI left a comment:

Pull Request Overview

This PR refactors the Dot operation to only accept 2D matrix inputs, simplifying the codebase by removing support for vector inputs and delegating vector operations to helper functions. The changes eliminate the old scipy_ger module and streamline the dot product implementation.

  • Restricts Dot Op to only accept 2D tensors (matrices), removing vector support
  • Removes scipy_ger module and integrates scipy directly into Ger.perform method
  • Updates dot interface function to handle vector promotion to matrices internally

Reviewed Changes

Copilot reviewed 20 out of 20 changed files in this pull request and generated 2 comments.

Summary per file:
pytensor/tensor/math.py: Major refactor of the Dot class to only accept 2D inputs; updated the dense_dot function
pytensor/tensor/blas.py: Updated Ger.perform to use scipy directly; simplified Dot22.perform
pytensor/tensor/blas_scipy.py: Removed the entire ScipyGer implementation
pytensor/tensor/rewriting/blas_scipy.py: Removed scipy-specific BLAS rewrites
pytensor/tensor/rewriting/math.py: Updated optimization rules for the new Dot constraints
tests/tensor/test_math.py: Updated tests to reflect the new Dot API and removed vector tests
Comments suppressed due to low confidence (1)

tests/tensor/test_math.py:2011

  • The test is checking that _dot(d1, d2) raises TypeError, but this line should be inside a pytest.raises context manager to properly test the exception.
            _dot(d1, d2)
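
The pattern Copilot is pointing at is roughly the following (a sketch; _dot, d1 and d2 are the names used in that test file):

import pytest

# Assert that the matrix-only core Op rejects these inputs
with pytest.raises(TypeError):
    _dot(d1, d2)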

    # Work on transposed system to avoid copying
    A = ger_func(alpha, y, x, a=A.T, overwrite_a=self.destructive).T
else:
    A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive)
Copilot AI commented on Jul 24, 2025:

The output should handle the case when A.size == 0. When A is empty, the code should still copy A if not destructive, but currently it just assigns A directly regardless of the destructive flag.

Suggested change (the original line, followed by the proposed replacement):

A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive)

A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive)
else:
    # Handle the case where A.size == 0
    if self.destructive:
        # No-op for destructive mode
        pass
    else:
        # Create a copy for non-destructive mode
        A = A.copy()


ricardoV94 (Member, Author) replied:

There's no point in copying an empty array; you can't store anything in it.
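
For reference, the transposed-system trick from the excerpt above can be reproduced with plain scipy (an illustrative sketch, not the Op's perform method; it relies on A.T of a C-ordered array being Fortran-contiguous, which is what BLAS GER needs to update in place):

import numpy as np
from scipy.linalg.blas import get_blas_funcs

x = np.arange(3, dtype="float64")
y = np.arange(4, dtype="float64")
A = np.zeros((3, 4), order="C")  # C-ordered, as NumPy arrays usually are
ger, = get_blas_funcs(("ger",), (A, x, y))

# Updating the transposed system A.T <- A.T + alpha * outer(y, x) in place
# avoids copying A into Fortran order; transposing back gives the result.
out = ger(1.0, y, x, a=A.T, overwrite_a=True).T

np.testing.assert_allclose(out, np.outer(x, y))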

constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0]))
return [constant_zero]
):
return [zeros((x.shape[0], y.shape[1]), dtype=node.outputs[0].type.dtype)]
Copilot AI commented on Jul 24, 2025:

This assumes both inputs are 2D matrices, but the function doesn't validate the input dimensions. If either x or y has fewer than 2 dimensions, accessing x.shape[0] or y.shape[1] could cause an IndexError.


ricardoV94 (Member, Author) replied:

That's the whole point of this PR: if you see a Dot, its inputs must be two matrices; make_node validates this.
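
A minimal sketch of what such a make_node guard can look like (a toy Op for illustration, not PyTensor's actual Dot implementation):

import numpy as np
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor.type import TensorType


class MatrixOnlyDot(Op):
    """Toy Op that, like the restricted Dot, only accepts 2d inputs."""

    def make_node(self, x, y):
        x = pt.as_tensor_variable(x)
        y = pt.as_tensor_variable(y)
        if x.type.ndim != 2 or y.type.ndim != 2:
            raise TypeError(
                f"Expected two matrices, got ndim={x.type.ndim} and ndim={y.type.ndim}"
            )
        out_type = TensorType(
            dtype=np.promote_types(x.type.dtype, y.type.dtype).name,
            shape=(x.type.shape[0], y.type.shape[1]),
        )
        return Apply(self, [x, y], [out_type()])

    def perform(self, node, inputs, output_storage):
        x, y = inputs
        output_storage[0][0] = np.dot(x, y)


m, n = pt.matrix("m"), pt.matrix("n")
MatrixOnlyDot()(m, n)  # builds a node: both inputs are 2d
try:
    MatrixOnlyDot()(pt.vector("v"), n)  # rejected before any rewrite sees it
except TypeError as exc:
    print(exc)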

codecov bot commented on Jul 24, 2025:

Codecov Report

Attention: Patch coverage is 95.58824% with 6 lines in your changes missing coverage. Please review.

Project coverage is 81.53%. Comparing base (12213d0) to head (aff4a71).

Files with missing lines Patch % Lines
pytensor/tensor/math.py 87.80% 2 Missing and 3 partials ⚠️
pytensor/tensor/rewriting/linalg.py 50.00% 0 Missing and 1 partial ⚠️

❌ Your patch check has failed because the patch coverage (95.58%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1538      +/-   ##
==========================================
+ Coverage   81.49%   81.53%   +0.03%     
==========================================
  Files         232      230       -2     
  Lines       53122    53002     -120     
  Branches     9444     9410      -34     
==========================================
- Hits        43292    43213      -79     
+ Misses       7382     7360      -22     
+ Partials     2448     2429      -19     
Files with missing lines Coverage Δ
pytensor/tensor/basic.py 91.84% <ø> (ø)
pytensor/tensor/blas.py 73.22% <100.00%> (-0.33%) ⬇️
pytensor/tensor/rewriting/blas.py 91.10% <100.00%> (+1.82%) ⬆️
pytensor/tensor/rewriting/math.py 90.30% <100.00%> (+1.02%) ⬆️
pytensor/tensor/rewriting/subtensor_lift.py 90.95% <100.00%> (-0.11%) ⬇️
pytensor/tensor/rewriting/linalg.py 92.06% <50.00%> (-0.02%) ⬇️
pytensor/tensor/math.py 92.87% <87.80%> (+0.08%) ⬆️

... and 2 files with indirect coverage changes

