Change Dot Op to only accept matrix inputs #1538

Merged
4 commits merged on Jul 25, 2025
1 change: 0 additions & 1 deletion pytensor/tensor/__init__.py
@@ -107,7 +107,6 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
from pytensor.tensor import (
blas,
blas_c,
blas_scipy,
sharedvar,
xlogx,
)
3 changes: 1 addition & 2 deletions pytensor/tensor/basic.py
@@ -1801,8 +1801,7 @@ def do_constant_folding(self, fgraph, node):
| pytensor.tensor.blas.Gemv
| pytensor.tensor.blas_c.CGemv
| pytensor.tensor.blas.Ger
| pytensor.tensor.blas_c.CGer
| pytensor.tensor.blas_scipy.ScipyGer,
| pytensor.tensor.blas_c.CGer,
)
):
# Ops that will work inplace on the Alloc. So if they
36 changes: 14 additions & 22 deletions pytensor/tensor/blas.py
@@ -83,6 +83,7 @@
from pathlib import Path

import numpy as np
from scipy.linalg import get_blas_funcs

from pytensor.graph import vectorize_graph
from pytensor.npy_2_compat import normalize_axis_tuple
@@ -288,18 +289,17 @@ def make_node(self, A, alpha, x, y):

return Apply(self, inputs, [A.type()])

def perform(self, node, inp, out):
cA, calpha, cx, cy = inp
(cZ,) = out
if self.destructive:
A = cA
else:
A = cA.copy()
if calpha != 1:
A += calpha * np.outer(cx, cy)
else:
A += np.outer(cx, cy)
cZ[0] = A
def perform(self, node, inputs, output_storage):
A, alpha, x, y = inputs
if A.size:
# GER doesn't handle zero-sized inputs
ger_func = get_blas_funcs("ger", dtype=A.dtype)
if A.flags["C_CONTIGUOUS"]:
# Work on transposed system to avoid copying
A = ger_func(alpha, y, x, a=A.T, overwrite_a=self.destructive).T
else:
A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive)
Copilot AI commented on Jul 24, 2025:
The output should handle the case when A.size == 0. When A is empty, the code should still copy A if not destructive, but currently it just assigns A directly regardless of the destructive flag.

Suggested change
A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive)
A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive)
else:
# Handle the case where A.size == 0
if self.destructive:
# No-op for destructive mode
pass
else:
# Create a copy for non-destructive mode
A = A.copy()


Member Author replied:
There's no point in copying an empty array, you can't store anything in it

output_storage[0][0] = A

def infer_shape(self, fgraph, node, input_shapes):
return [input_shapes[0]]
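
As an aside on the "work on transposed system" trick above: scipy's BLAS ger performs the rank-1 update on a Fortran-ordered array, and for a C-contiguous A the transpose A.T is Fortran-contiguous, so swapping the two vector operands and updating A.T gives the same result as updating A directly, without an extra copy. A small standalone check (illustrative only, not part of this PR):

    # Sketch: verify ger(alpha, y, x, a=A.T).T == A + alpha * outer(x, y)
    import numpy as np
    from scipy.linalg import get_blas_funcs

    rng = np.random.default_rng(0)
    A = rng.standard_normal((3, 4))   # C-contiguous by default
    x = rng.standard_normal(3)
    y = rng.standard_normal(4)
    alpha = 0.5

    expected = A + alpha * np.outer(x, y)

    ger = get_blas_funcs("ger", dtype=A.dtype)
    # Swap x and y and update A.T (a Fortran-contiguous view), then transpose back
    out = ger(alpha, y, x, a=A.T, overwrite_a=False).T

    np.testing.assert_allclose(out, expected)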
@@ -1128,16 +1128,8 @@ def make_node(self, x, y):
outputs = [tensor(dtype=x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))]
return Apply(self, [x, y], outputs)

def perform(self, node, inp, out):
x, y = inp
(z,) = out
try:
z[0] = np.asarray(np.dot(x, y))
except ValueError as e:
# The error raised by numpy has no shape information, we mean to
# add that
e.args = (*e.args, x.shape, y.shape)
raise
def perform(self, node, inputs, output_storage):
output_storage[0][0] = np.dot(*inputs)

def infer_shape(self, fgraph, node, input_shapes):
return [[input_shapes[0][0], input_shapes[1][1]]]
34 changes: 0 additions & 34 deletions pytensor/tensor/blas_scipy.py

This file was deleted.

164 changes: 52 additions & 112 deletions pytensor/tensor/math.py
@@ -40,12 +40,13 @@
get_normalized_batch_axes,
scalar_elemwise,
)
from pytensor.tensor.shape import shape, specify_broadcastable
from pytensor.tensor.shape import shape, specify_shape
from pytensor.tensor.type import (
DenseTensorType,
complex_dtypes,
continuous_dtypes,
discrete_dtypes,
float_dtypes,
int_dtypes,
tensor,
uint_dtypes,
@@ -2986,9 +2987,7 @@ def clip(x, min, max):

class Dot(Op):
"""
Computes the dot product of two variables. For two matrices, this is
equivalent to matrix multiplication. For two vectors, this is the inner
product.
Computes the dot product of two matrix variables.

Notes
-----
@@ -3001,97 +3000,58 @@ class Dot(Op):

"""

gufunc_signature = "(m,n),(n,p)->(m,p)"
gufunc_spec = ("matmul", 2, 1)
__props__ = ()

# the rationale for Dot22 is related to getting GEMM Ops into the
# graph. See Dot22 in tensor.blas for details.

def make_node(self, *inputs):
inputs = list(map(as_tensor_variable, inputs))
def make_node(self, x, y):
x = as_tensor_variable(x)
y = as_tensor_variable(y)

if len(inputs) != 2:
raise TypeError(f"Two arguments required, {len(inputs)} given ")
if inputs[0].ndim not in (1, 2):
if x.type.ndim != 2:
raise TypeError(
"Input 0 (0-indexed) must have ndim of "
f"1 or 2, {int(inputs[0].ndim)} given. Consider calling "
"pytensor.tensor.dot instead."
f"Dot Op expects a 2D tensor as input 0, got {x} with {x.type.ndim} dimensions"
)
if inputs[1].ndim not in (1, 2):
if y.type.ndim != 2:
raise TypeError(
"Input 1 (0-indexed) must have ndim of "
f"1 or 2, {int(inputs[1].ndim)} given. Consider calling "
"pytensor.tensor.dot instead."
f"Dot Op expects a 2D tensor as input 1, got {y} with {y.type.ndim} dimensions"
)

sx, sy = (input.type.shape for input in inputs)
sx, sy = x.type.shape, y.type.shape
if sx[-1] is not None and sy[0] is not None and sx[-1] != sy[0]:
raise ValueError(
f"Incompatible shared dimension for dot product: {sx}, {sy}"
)
out_shape = (sx[0], sy[1])
out_dtype = ps.upcast(x.type.dtype, y.type.dtype)
outputs = [tensor(dtype=out_dtype, shape=out_shape)]
return Apply(self, [x, y], outputs)

if len(sy) == 2:
sz = sx[:-1] + sy[-1:]
elif len(sy) == 1:
sz = sx[:-1]

i_dtypes = [input.type.dtype for input in inputs]
outputs = [tensor(dtype=ps.upcast(*i_dtypes), shape=sz)]
return Apply(self, inputs, outputs)

def perform(self, node, inp, out):
x, y = inp
(z,) = out

# the asarray is here because dot between two vectors
# gives a numpy float object but we need to return a 0d
# ndarray
z[0] = np.asarray(np.dot(x, y))
def perform(self, node, inputs, output_storage):
output_storage[0][0] = np.matmul(*inputs)
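
For context, a hedged usage sketch of the new behaviour (assuming the public pt.matrix / pt.vector / pt.dot helpers): the core Dot Op is now matrix-only, while the user-facing dot still accepts vectors because dense_dot promotes them to matrices and squeezes the result (see the dense_dot change further down).

    import pytensor.tensor as pt

    m = pt.matrix("m")   # 2D: accepted by the core Dot Op directly
    v = pt.vector("v")   # 1D: no longer accepted by Dot itself

    z = pt.dot(m, v)     # dense_dot promotes v to a column matrix, then squeezes
    assert z.type.ndim == 1

    s = pt.dot(v, v)     # two vectors: promoted to (1, n) @ (n, 1), squeezed to a scalar
    assert s.type.ndim == 0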

def grad(self, inp, grads):
x, y = inp
(gz,) = grads
xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim

# grad is scalar, so x is vector and y is vector
if gdim == 0:
xgrad = gz * y
ygrad = gz * x

# x is vector, y is matrix, grad is vector
elif xdim == 1 and ydim == 2:
xgrad = dot(gz, y.T)
ygrad = outer(x.T, gz)

# x is matrix, y is vector, grad is vector
elif xdim == 2 and ydim == 1:
xgrad = outer(gz, y.T)
ygrad = dot(x.T, gz)

# x is matrix, y is matrix, grad is matrix
elif xdim == ydim == 2:
xgrad = dot(gz, y.T)
ygrad = dot(x.T, gz)
xgrad = self(gz, y.T)
ygrad = self(x.T, gz)

# If x or y contain broadcastable dimensions but only one of
# them know that a matching dimensions is broadcastable, the
# above code don't always return the right broadcast pattern.
# This cause problem down the road. See gh-1461.
if xgrad.broadcastable != x.broadcastable:
xgrad = specify_broadcastable(
xgrad, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b)
)
if ygrad.broadcastable != y.broadcastable:
ygrad = specify_broadcastable(
ygrad, *(ax for (ax, b) in enumerate(y.type.broadcastable) if b)
)

rval = xgrad, ygrad
if xgrad.type.shape != x.type.shape:
Member commented:
It seems like the logic the comment is worried about should be implemented in the make_node, so that these checks here aren't necessary?

Member Author replied:
I don't see how that's related to the grad here

xgrad = specify_shape(xgrad, x.type.shape)
if ygrad.type.shape != y.type.shape:
ygrad = specify_shape(ygrad, y.type.shape)

for elem in rval:
assert elem.dtype.find("float") != -1
if xgrad.type.dtype not in float_dtypes:
raise TypeError("Dot grad x output must be a float type")
if ygrad.type.dtype not in float_dtypes:
raise TypeError("Dot grad y output must be a float type")

return rval
return xgrad, ygrad
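
The matrix-matrix gradient rules used here are the standard ones: for Z = X @ Y with upstream gradient gZ, the gradient with respect to X is gZ @ Y.T and with respect to Y is X.T @ gZ. A small NumPy finite-difference check (illustrative only, not from the PR):

    import numpy as np

    rng = np.random.default_rng(42)
    X = rng.standard_normal((2, 3))
    Y = rng.standard_normal((3, 4))
    gZ = rng.standard_normal((2, 4))    # arbitrary upstream gradient

    xgrad = gZ @ Y.T                    # dL/dX
    ygrad = X.T @ gZ                    # dL/dY

    # Perturb X[0, 0] numerically and compare against xgrad[0, 0]
    eps = 1e-6
    Xp = X.copy()
    Xp[0, 0] += eps
    numeric = ((Xp @ Y - X @ Y) * gZ).sum() / eps
    assert np.isclose(numeric, xgrad[0, 0], atol=1e-4)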

def R_op(self, inputs, eval_points):
# R_op for a \dot b evaluated at c for a and d for b is
@@ -3116,24 +3076,7 @@ def R_op(self, inputs, eval_points):

def infer_shape(self, fgraph, node, shapes):
xshp, yshp = shapes
x, y = node.inputs

# vector / vector
if x.ndim == 1 and y.ndim == 1:
return [()]
# matrix / vector
if x.ndim == 2 and y.ndim == 1:
return [xshp[:-1]]
# vector / matrix
if x.ndim == 1 and y.ndim == 2:
return [yshp[-1:]]
# matrix / matrix
if x.ndim == 2 and y.ndim == 2:
return [xshp[:-1] + yshp[-1:]]
raise NotImplementedError()

def __str__(self):
return "dot"
return [[xshp[0], yshp[1]]]
Member commented:
The output shape in make_node seemed to be worried about batch dims, but infer_shape doesn't have to?

Member Author replied:
infer_shape takes a well defined node and input shapes. Otherwise every op.infer_shape would be skeptical that it got the right number of shapes, and we don't do that



_dot = Dot()
@@ -3215,7 +3158,24 @@ def dense_dot(a, b):
elif a.ndim > 2 or b.ndim > 2:
return tensordot(a, b, [[a.ndim - 1], [np.maximum(0, b.ndim - 2)]])
else:
return _dot(a, b)
row_vector = a.ndim == 1
if row_vector:
# Promote to row matrix
a = a[None]

col_vector = b.ndim == 1
if col_vector:
# Promote to column matrix
b = b[:, None]

out = _dot(a, b)
if row_vector:
# If we promoted a to a row matrix, we need to squeeze the first dimension
out = out.squeeze(0)
if col_vector:
# If we promoted b to a column matrix, we need to squeeze the last dimension
out = out.squeeze(-1)
return out
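
A minimal NumPy sketch (illustrative, not part of the PR) of the promote-then-squeeze bookkeeping above: a 1D operand is lifted to a matrix so that the matrix-only Dot applies, and the temporary axis is squeezed away afterwards.

    import numpy as np

    a = np.arange(3.0)                  # shape (3,): treated as a row matrix (1, 3)
    b = np.arange(12.0).reshape(3, 4)   # shape (3, 4)

    out = (a[None] @ b).squeeze(0)      # (1, 4) -> (4,)
    assert out.shape == (4,)
    assert np.allclose(out, a @ b)      # matches NumPy's own vector-matrix dot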


def tensordot(
@@ -3921,11 +3881,7 @@ def logsumexp(x, axis=None, keepdims=False):
return log(sum(exp(x), axis=axis, keepdims=keepdims))


_matmul = Blockwise(
_dot,
signature="(m,k),(k,n)->(m,n)",
gufunc_spec=("numpy.matmul", 2, 1),
)
_matmul = Blockwise(_dot, name="Matmul")
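
Because Dot now declares gufunc_signature "(m,n),(n,p)->(m,p)" and gufunc_spec ("matmul", 2, 1) on the class, Blockwise can presumably pick those up from the core Op, so the explicit signature and gufunc_spec arguments are no longer needed here. A hedged usage sketch (assuming the pt.tensor3 and pt.matmul helpers, which are not part of this diff):

    import pytensor.tensor as pt

    x = pt.tensor3("x")   # e.g. shape (batch, m, k)
    y = pt.tensor3("y")   # e.g. shape (batch, k, n)

    z = pt.matmul(x, y)   # batched inputs route to _matmul, the Blockwise-wrapped Dot
    assert z.type.ndim == 3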


def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
@@ -3975,7 +3931,7 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
if x1.type.ndim == 0 or x2.type.ndim == 0:
raise ValueError("matmul operand cannot be scalar")
if x1.type.ndim == 1 and x2.type.ndim == 1:
out = _dot(x1, x2)
out = vecdot(x1, x2)
elif x1.type.ndim == 1:
out = vecmat(x1, x2)
elif x2.type.ndim == 1:
@@ -4139,23 +4095,7 @@ def vecmat(

@_vectorize_node.register(Dot)
def vectorize_node_dot(op, node, batched_x, batched_y):
old_x, old_y = node.inputs
old_x_ndim = old_x.type.ndim
old_y_ndim = old_y.type.ndim
match (old_x_ndim, old_y_ndim):
case (1, 1):
batch_fn = vecdot
case (2, 1):
batch_fn = matvec
case (1, 2):
batch_fn = vecmat
case (2, 2):
batch_fn = matmul
case _:
raise ValueError(
f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
)
return batch_fn(batched_x, batched_y).owner
return matmul(batched_x, batched_y).owner


def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
1 change: 0 additions & 1 deletion pytensor/tensor/rewriting/__init__.py
@@ -1,7 +1,6 @@
import pytensor.tensor.rewriting.basic
import pytensor.tensor.rewriting.blas
import pytensor.tensor.rewriting.blas_c
import pytensor.tensor.rewriting.blas_scipy
import pytensor.tensor.rewriting.blockwise
import pytensor.tensor.rewriting.einsum
import pytensor.tensor.rewriting.elemwise