Change Dot Op to only accept matrix inputs #1538
@@ -40,12 +40,13 @@
    get_normalized_batch_axes,
    scalar_elemwise,
)
from pytensor.tensor.shape import shape, specify_broadcastable
from pytensor.tensor.shape import shape, specify_shape
from pytensor.tensor.type import (
    DenseTensorType,
    complex_dtypes,
    continuous_dtypes,
    discrete_dtypes,
    float_dtypes,
    int_dtypes,
    tensor,
    uint_dtypes,
@@ -2986,9 +2987,7 @@ def clip(x, min, max):

class Dot(Op):
    """
    Computes the dot product of two variables. For two matrices, this is
    equivalent to matrix multiplication. For two vectors, this is the inner
    product.
    Computes the dot product of two matrix variables.

    Notes
    -----
@@ -3001,97 +3000,58 @@ class Dot(Op):
    """

    gufunc_signature = "(m,n),(n,p)->(m,p)"
    gufunc_spec = ("matmul", 2, 1)
    __props__ = ()

    # the rationale for Dot22 is related to getting GEMM Ops into the
    # graph. See Dot22 in tensor.blas for details.

    def make_node(self, *inputs):
        inputs = list(map(as_tensor_variable, inputs))
    def make_node(self, x, y):
        x = as_tensor_variable(x)
        y = as_tensor_variable(y)

        if len(inputs) != 2:
            raise TypeError(f"Two arguments required, {len(inputs)} given ")
        if inputs[0].ndim not in (1, 2):
        if x.type.ndim != 2:
            raise TypeError(
                "Input 0 (0-indexed) must have ndim of "
                f"1 or 2, {int(inputs[0].ndim)} given. Consider calling "
                "pytensor.tensor.dot instead."
                f"Dot Op expects a 2D tensor as input 0, got {x} with {x.type.ndim} dimensions"
            )
        if inputs[1].ndim not in (1, 2):
        if y.type.ndim != 2:
            raise TypeError(
                "Input 1 (0-indexed) must have ndim of "
                f"1 or 2, {int(inputs[1].ndim)} given. Consider calling "
                "pytensor.tensor.dot instead."
                f"Dot Op expects a 2D tensor as input 1, got {y} with {y.type.ndim} dimensions"
            )

        sx, sy = (input.type.shape for input in inputs)
        sx, sy = x.type.shape, y.type.shape
        if sx[-1] is not None and sy[0] is not None and sx[-1] != sy[0]:
            raise ValueError(
                f"Incompatible shared dimension for dot product: {sx}, {sy}"
            )
        out_shape = (sx[0], sy[1])
        out_dtype = ps.upcast(x.type.dtype, y.type.dtype)
        outputs = [tensor(dtype=out_dtype, shape=out_shape)]
        return Apply(self, [x, y], outputs)

        if len(sy) == 2:
            sz = sx[:-1] + sy[-1:]
        elif len(sy) == 1:
            sz = sx[:-1]

        i_dtypes = [input.type.dtype for input in inputs]
        outputs = [tensor(dtype=ps.upcast(*i_dtypes), shape=sz)]
        return Apply(self, inputs, outputs)
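As a quick illustration (not part of the diff): the new `make_node` only accepts two matrices, checks that the shared dimension matches, and infers an `(m, p)` output with the upcast dtype. A minimal NumPy sketch of that contract, where `check_dot_contract` is a hypothetical helper used purely for illustration:

```python
import numpy as np

def check_dot_contract(x, y):
    # Hypothetical helper mirroring the new make_node rules:
    # both operands must be matrices and the shared dimension must agree.
    if x.ndim != 2 or y.ndim != 2:
        raise TypeError("Dot Op expects 2D tensors")
    if x.shape[1] != y.shape[0]:
        raise ValueError("Incompatible shared dimension for dot product")
    out_shape = (x.shape[0], y.shape[1])          # (m, p), as in the diff
    out_dtype = np.result_type(x.dtype, y.dtype)  # analogue of ps.upcast
    return out_shape, out_dtype

x = np.zeros((3, 4), dtype="float32")
y = np.zeros((4, 5), dtype="float64")
assert check_dot_contract(x, y) == ((3, 5), np.dtype("float64"))
```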
    def perform(self, node, inp, out):
        x, y = inp
        (z,) = out

        # the asarray is here because dot between two vectors
        # gives a numpy float object but we need to return a 0d
        # ndarray
        z[0] = np.asarray(np.dot(x, y))
    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = np.matmul(*inputs)

    def grad(self, inp, grads):
        x, y = inp
        (gz,) = grads
        xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim

        # grad is scalar, so x is vector and y is vector
        if gdim == 0:
            xgrad = gz * y
            ygrad = gz * x

        # x is vector, y is matrix, grad is vector
        elif xdim == 1 and ydim == 2:
            xgrad = dot(gz, y.T)
            ygrad = outer(x.T, gz)

        # x is matrix, y is vector, grad is vector
        elif xdim == 2 and ydim == 1:
            xgrad = outer(gz, y.T)
            ygrad = dot(x.T, gz)

        # x is matrix, y is matrix, grad is matrix
        elif xdim == ydim == 2:
            xgrad = dot(gz, y.T)
            ygrad = dot(x.T, gz)
        xgrad = self(gz, y.T)
        ygrad = self(x.T, gz)

        # If x or y contain broadcastable dimensions but only one of
        # them know that a matching dimensions is broadcastable, the
        # above code don't always return the right broadcast pattern.
        # This cause problem down the road. See gh-1461.
        if xgrad.broadcastable != x.broadcastable:
            xgrad = specify_broadcastable(
                xgrad, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b)
            )
        if ygrad.broadcastable != y.broadcastable:
            ygrad = specify_broadcastable(
                ygrad, *(ax for (ax, b) in enumerate(y.type.broadcastable) if b)
            )

        rval = xgrad, ygrad
        if xgrad.type.shape != x.type.shape:
Review comment: It seems like the logic the comment is worried about should be implemented in the […]

Reply: I don't see how that's related to the grad here
            xgrad = specify_shape(xgrad, x.type.shape)
        if ygrad.type.shape != y.type.shape:
            ygrad = specify_shape(ygrad, y.type.shape)

        for elem in rval:
            assert elem.dtype.find("float") != -1
        if xgrad.type.dtype not in float_dtypes:
            raise TypeError("Dot grad x output must be a float type")
        if ygrad.type.dtype not in float_dtypes:
            raise TypeError("Dot grad y output must be a float type")

        return rval
        return xgrad, ygrad

    def R_op(self, inputs, eval_points):
        # R_op for a \dot b evaluated at c for a and d for b is
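As a quick aside (my own check, not from the PR): the simplified `grad` above relies on the standard matrix-product identities, dL/dX = G Y^T and dL/dY = X^T G for Z = X Y with upstream gradient G. A short NumPy finite-difference sketch of the first identity:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4))
y = rng.normal(size=(4, 5))
gz = rng.normal(size=(3, 5))  # upstream gradient dL/dZ for Z = x @ y

xgrad = gz @ y.T   # identity used for dL/dx
ygrad = x.T @ gz   # identity used for dL/dy

# Finite-difference check of xgrad on L = sum(gz * (x @ y))
eps = 1e-6
num_xgrad = np.zeros_like(x)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        dx = np.zeros_like(x)
        dx[i, j] = eps
        num_xgrad[i, j] = (np.sum(gz * ((x + dx) @ y)) - np.sum(gz * (x @ y))) / eps
np.testing.assert_allclose(xgrad, num_xgrad, rtol=1e-4, atol=1e-6)
```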
@@ -3116,24 +3076,7 @@ def R_op(self, inputs, eval_points):

    def infer_shape(self, fgraph, node, shapes):
        xshp, yshp = shapes
        x, y = node.inputs

        # vector / vector
        if x.ndim == 1 and y.ndim == 1:
            return [()]
        # matrix / vector
        if x.ndim == 2 and y.ndim == 1:
            return [xshp[:-1]]
        # vector / matrix
        if x.ndim == 1 and y.ndim == 2:
            return [yshp[-1:]]
        # matrix / matrix
        if x.ndim == 2 and y.ndim == 2:
            return [xshp[:-1] + yshp[-1:]]
        raise NotImplementedError()

    def __str__(self):
        return "dot"
        return [[xshp[0], yshp[1]]]
Review comment: The output shape in […]
_dot = Dot()
@@ -3215,7 +3158,24 @@ def dense_dot(a, b):
    elif a.ndim > 2 or b.ndim > 2:
        return tensordot(a, b, [[a.ndim - 1], [np.maximum(0, b.ndim - 2)]])
    else:
        return _dot(a, b)
        row_vector = a.ndim == 1
        if row_vector:
            # Promote to row matrix
            a = a[None]

        col_vector = b.ndim == 1
        if col_vector:
            # Promote to column matrix
            b = b[:, None]

        out = _dot(a, b)
        if row_vector:
            # If we promoted a to a row matrix, we need to squeeze the first dimension
            out = out.squeeze(0)
        if col_vector:
            # If we promoted b to a column matrix, we need to squeeze the last dimension
            out = out.squeeze(-1)
        return out
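For illustration (not from the PR): the promote-then-squeeze pattern above keeps the core Op strictly 2D while reproducing `np.dot` semantics for 1D operands. A small NumPy sketch of the same trick, with `dot_via_matmul_2d` as a made-up name:

```python
import numpy as np

def dot_via_matmul_2d(a, b):
    # Promote vectors to matrices, call a 2D-only product, squeeze back.
    row_vector = a.ndim == 1
    if row_vector:
        a = a[None]          # (n,) -> (1, n)
    col_vector = b.ndim == 1
    if col_vector:
        b = b[:, None]       # (n,) -> (n, 1)
    out = a @ b              # stand-in for the 2D-only Dot Op
    if row_vector:
        out = out.squeeze(0)
    if col_vector:
        out = out.squeeze(-1)
    return out

rng = np.random.default_rng(1)
v, w = rng.normal(size=4), rng.normal(size=4)
M = rng.normal(size=(4, 3))
assert np.allclose(dot_via_matmul_2d(v, w), np.dot(v, w))      # vector . vector
assert np.allclose(dot_via_matmul_2d(v, M), np.dot(v, M))      # vector . matrix
assert np.allclose(dot_via_matmul_2d(M.T, v), np.dot(M.T, v))  # matrix . vector
```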
def tensordot(
@@ -3921,11 +3881,7 @@ def logsumexp(x, axis=None, keepdims=False):
    return log(sum(exp(x), axis=axis, keepdims=keepdims))


_matmul = Blockwise(
    _dot,
    signature="(m,k),(k,n)->(m,n)",
    gufunc_spec=("numpy.matmul", 2, 1),
)
_matmul = Blockwise(_dot, name="Matmul")


def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
@@ -3975,7 +3931,7 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
    if x1.type.ndim == 0 or x2.type.ndim == 0:
        raise ValueError("matmul operand cannot be scalar")
    if x1.type.ndim == 1 and x2.type.ndim == 1:
        out = _dot(x1, x2)
        out = vecdot(x1, x2)
    elif x1.type.ndim == 1:
        out = vecmat(x1, x2)
    elif x2.type.ndim == 1:
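As a quick aside (not from the PR): routing the 1D-1D case through `vecdot` instead of the core `Dot` Op does not change the semantics; it is still the ordinary inner product, as this NumPy check illustrates:

```python
import numpy as np

rng = np.random.default_rng(2)
x1, x2 = rng.normal(size=5), rng.normal(size=5)

# The 1D-1D case of matmul is the plain inner product, whichever Op produces it.
assert np.allclose(np.matmul(x1, x2), np.sum(x1 * x2))
assert np.allclose(np.dot(x1, x2), np.sum(x1 * x2))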
@@ -4139,23 +4095,7 @@ def vecmat(

@_vectorize_node.register(Dot)
def vectorize_node_dot(op, node, batched_x, batched_y):
    old_x, old_y = node.inputs
    old_x_ndim = old_x.type.ndim
    old_y_ndim = old_y.type.ndim
    match (old_x_ndim, old_y_ndim):
        case (1, 1):
            batch_fn = vecdot
        case (2, 1):
            batch_fn = matvec
        case (1, 2):
            batch_fn = vecmat
        case (2, 2):
            batch_fn = matmul
        case _:
            raise ValueError(
                f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
            )
    return batch_fn(batched_x, batched_y).owner
    return matmul(batched_x, batched_y).owner
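For illustration (not from the PR): once the core Op is always 2D x 2D, vectorizing it over leading batch dimensions is exactly NumPy-style `matmul` broadcasting, which is why the dispatch table above collapses to a single `matmul` call. A short NumPy sketch of that equivalence:

```python
import numpy as np

rng = np.random.default_rng(3)
batched_x = rng.normal(size=(7, 3, 4))  # batch of 7 (3, 4) matrices
batched_y = rng.normal(size=(7, 4, 5))  # batch of 7 (4, 5) matrices

# Looping the 2D dot over the batch gives the same result as one matmul.
looped = np.stack([np.dot(a, b) for a, b in zip(batched_x, batched_y)])
assert np.allclose(np.matmul(batched_x, batched_y), looped)
```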
def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
Copilot review comment: The output should handle the case when A.size == 0. When A is empty, the code should still copy A if not destructive, but currently it just assigns A directly regardless of the destructive flag.

Reply: There's no point in copying an empty array, you can't store anything in it
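The routine being discussed is not shown in this diff; the sketch below uses entirely hypothetical names and only illustrates the two behaviors the reviewers are weighing for an empty, non-destructive input.

```python
import numpy as np

def scale_maybe_inplace(A, alpha, destructive=False):
    # Hypothetical stand-in for the routine under review.
    if A.size == 0:
        # Behavior defended in the reply: return A as-is; there is nothing
        # to overwrite in an empty array. The Copilot suggestion would
        # instead copy here too whenever destructive is False.
        return A
    if not destructive:
        A = A.copy()  # preserve the caller's array
    A *= alpha
    return A
```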