
Improve dot lift rewrites #1471


Open: wants to merge 11 commits into main
53 changes: 21 additions & 32 deletions pytensor/tensor/math.py
@@ -3025,6 +3025,11 @@
)

sx, sy = (input.type.shape for input in inputs)
if sx[-1] is not None and sy[0] is not None and sx[-1] != sy[0]:
    raise ValueError(
        f"Incompatible shared dimension for dot product: {sx}, {sy}"
    )

if len(sy) == 2:
sz = sx[:-1] + sy[-1:]
elif len(sy) == 1:
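
The new static shape check surfaces mismatched inner dimensions at graph-construction time instead of at runtime. A minimal sketch of the intended behaviour, assuming `pt.dot` reaches this code path with fully static shapes:

```python
import pytensor.tensor as pt

x = pt.matrix("x", shape=(2, 3))
y = pt.matrix("y", shape=(4, 5))  # inner dimensions 3 and 4 disagree

try:
    pt.dot(x, y)
except ValueError as err:
    # Expected: "Incompatible shared dimension for dot product: (2, 3), (4, 5)"
    print(err)
```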
@@ -3916,23 +3921,7 @@
return log(sum(exp(x), axis=axis, keepdims=keepdims))


# Predefine all batched variations of Dot
_inner_prod = Blockwise(
_dot,
signature="(n),(n)->()",
)

_matrix_vec_prod = Blockwise(
_dot,
signature="(m,k),(k)->(m)",
)

_vec_matrix_prod = Blockwise(
_dot,
signature="(k),(k,n)->(n)",
)

_matrix_matrix_matmul = Blockwise(
_matmul = Blockwise(
_dot,
signature="(m,k),(k,n)->(m,n)",
gufunc_spec=("numpy.matmul", 2, 1),
Expand Down Expand Up @@ -3988,11 +3977,11 @@
if x1.type.ndim == 1 and x2.type.ndim == 1:
out = _dot(x1, x2)
elif x1.type.ndim == 1:
out = _matrix_matrix_matmul(x1[None], x2).squeeze(-2)
out = vecmat(x1, x2)
elif x2.type.ndim == 1:
out = _matrix_matrix_matmul(x1, x2[:, None]).squeeze(-1)
out = matvec(x1, x2)
else:
out = _matrix_matrix_matmul(x1, x2)
out = _matmul(x1, x2)

if dtype is not None:
out = out.astype(dtype)
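
With the batched dot variants consolidated into a single `_matmul` Blockwise, the 1-D cases of `matmul` are expressed through the public `vecmat`/`matvec` helpers instead of manual padding and squeezing. A hedged usage sketch:

```python
import pytensor.tensor as pt

A = pt.matrix("A", shape=(3, 4))
v = pt.vector("v", shape=(3,))
w = pt.vector("w", shape=(4,))

# Mirrors numpy.matmul semantics: a 1-D first operand acts as a row
# vector, a 1-D second operand as a column vector.
row = pt.matmul(v, A)  # shape (4,), built via vecmat
col = pt.matmul(A, w)  # shape (3,), built via matvec
```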
Expand Down Expand Up @@ -4042,7 +4031,7 @@
>>> z_batch = pt.vecdot(x_batch, y_batch) # shape (3,)
>>> # Equivalent to numpy.vecdot(x_batch, y_batch)
"""
out = _inner_prod(x1, x2)
out = matmul(x1[..., None, :], x2[..., :, None]).squeeze((-2, -1))

if dtype is not None:
out = out.astype(dtype)
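
The padding-and-squeeze trick above reduces `vecdot` to the same `_matmul` Blockwise. A NumPy sketch of the identity it relies on:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))  # batch of five 3-vectors
y = rng.normal(size=(5, 3))

# (..., 1, k) @ (..., k, 1) -> (..., 1, 1); dropping both unit axes
# leaves the batched inner product.
via_matmul = np.matmul(x[..., None, :], y[..., :, None]).squeeze((-2, -1))
reference = np.einsum("...k,...k->...", x, y)
assert np.allclose(via_matmul, reference)
```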
Expand Down Expand Up @@ -4091,7 +4080,7 @@
>>> result = pt.matvec(batched_A, batched_v) # shape (2, 3)
>>> # Equivalent to numpy.matvec(batched_A, batched_v)
"""
out = _matrix_vec_prod(x1, x2)
out = matmul(x1, x2[..., None]).squeeze(-1)

if dtype is not None:
out = out.astype(dtype)
Expand Down Expand Up @@ -4129,18 +4118,18 @@
--------
>>> import pytensor.tensor as pt
>>> # Vector-matrix product
>>> v = pt.vector("v", shape=(3,)) # shape (3,)
>>> A = pt.matrix("A", shape=(3, 4)) # shape (3, 4)
>>> v = pt.vector("v", shape=(3,))
>>> A = pt.matrix("A", shape=(3, 4))
>>> result = pt.vecmat(v, A) # shape (4,)
>>> # Equivalent to numpy.vecmat(v, A)
>>>
>>> # Batched vector-matrix product
>>> batched_v = pt.matrix("v", shape=(2, 3)) # shape (2, 3)
>>> batched_A = pt.tensor3("A", shape=(2, 3, 4)) # shape (2, 3, 4)
>>> batched_v = pt.matrix("v", shape=(2, 3))
>>> batched_A = pt.tensor3("A", shape=(2, 3, 4))
>>> result = pt.vecmat(batched_v, batched_A) # shape (2, 4)
>>> # Equivalent to numpy.vecmat(batched_v, batched_A)
"""
out = _vec_matrix_prod(x1, x2)
out = matmul(x2.mT, x1[..., None]).squeeze(-1)

if dtype is not None:
out = out.astype(dtype)
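
Likewise, `matvec` and `vecmat` are now thin wrappers around `matmul`; in particular `vecmat(v, A)` is written as `matmul(A.mT, v[..., None]).squeeze(-1)`. A NumPy sketch of that identity:

```python
import numpy as np

rng = np.random.default_rng(0)
v = rng.normal(size=(2, 3))     # batched vectors
A = rng.normal(size=(2, 3, 4))  # batched matrices

via_matmul = np.matmul(np.swapaxes(A, -1, -2), v[..., None]).squeeze(-1)
reference = np.einsum("...k,...kn->...n", v, A)
assert np.allclose(via_matmul, reference)
```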
@@ -4155,18 +4144,18 @@
old_y_ndim = old_y.type.ndim
match (old_x_ndim, old_y_ndim):
case (1, 1):
batch_op = _inner_prod
batch_fn = vecdot
case (2, 1):
batch_op = _matrix_vec_prod
batch_fn = matvec
case (1, 2):
batch_op = _vec_matrix_prod
batch_fn = vecmat
case (2, 2):
batch_op = _matrix_matrix_matmul
batch_fn = matmul
case _:
raise ValueError(
f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
)
return batch_op(batched_x, batched_y).owner
return batch_fn(batched_x, batched_y).owner
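
The dispatch above means that batching a core dot over leading dimensions goes through the public helpers, which all lower to the single `_matmul` Blockwise. A hedged sketch, assuming `vectorize_graph` is importable from `pytensor.graph.replace`:

```python
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

A = pt.matrix("A", shape=(3, 4))
v = pt.vector("v", shape=(4,))
out = pt.dot(A, v)  # core (2, 1) dot; batching routes it through matvec

batched_A = pt.tensor3("A_b", shape=(5, 3, 4))
batched_v = pt.matrix("v_b", shape=(5, 4))
batched_out = vectorize_graph(out, replace={A: batched_A, v: batched_v})
```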


def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
13 changes: 9 additions & 4 deletions pytensor/tensor/rewriting/blas.py
@@ -98,7 +98,7 @@
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import (
Dot,
_matrix_matrix_matmul,
_matmul,
add,
mul,
neg,
@@ -758,7 +758,7 @@
ignore_newtrees=False,
),
"fast_run",
position=15,
position=11,
)


@@ -903,19 +903,23 @@
"local_dot22_to_dot22scalar",
in2out(local_dot22_to_dot22scalar),
"fast_run",
position=11,
position=12,
)


@register_specialize
@node_rewriter([_matrix_matrix_matmul])
@node_rewriter([_matmul])
def specialize_matmul_to_batched_dot(fgraph, node):
"""Rewrite Matmul (Blockwise matrix-matrix) without implicit broadcasted batched dimension as BatchedDot.

TODO: Do the same for Blockwise BatchedDot
"""
x, y = node.inputs

if x.type.ndim < 3:
# This doesn't actually have a batch dimension
return None

# BatchedDot does not allow implicit broadcasting of the batch dimensions
# We do not want to explicitly broadcast as it may result in huge arrays
if x.type.broadcastable[:-2] != y.type.broadcastable[:-2]:
@@ -926,6 +930,7 @@
if len(x_shape) > 3:
# If we have more than one batch dim, ravel it
x = x.reshape((-1, x_shape[-2], x_shape[-1]))
if len(y_shape) > 3:
y = y.reshape((-1, y_shape[-2], y_shape[-1]))

new_out = _batched_dot(x, y)
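
The guard above skips the rewrite when there is no true batch dimension, and the reshape step ravels any extra leading batch dimensions of each operand independently, since `BatchedDot` handles exactly one batch axis. A NumPy sketch of the raveling, assuming both operands carry the same batch shape:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 3, 4))  # two leading batch dimensions
y = rng.normal(size=(2, 5, 4, 6))

xr = x.reshape((-1, *x.shape[-2:]))  # (10, 3, 4)
yr = y.reshape((-1, *y.shape[-2:]))  # (10, 4, 6)
out = np.matmul(xr, yr).reshape((*x.shape[:-2], 3, 6))

assert np.allclose(out, np.matmul(x, y))
```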
95 changes: 33 additions & 62 deletions pytensor/tensor/rewriting/elemwise.py
@@ -30,21 +30,17 @@
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
from pytensor.tensor.basic import (
MakeVector,
alloc,
cast,
constant,
get_underlying_scalar_constant_value,
)
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import add, exp, mul
from pytensor.tensor.rewriting.basic import (
alloc_like,
broadcasted_by,
register_canonicalize,
register_specialize,
register_stabilize,
)
from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.variable import TensorConstant, TensorVariable


@@ -346,6 +342,7 @@ def is_dimshuffle_useless(new_order, input):


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([DimShuffle])
def local_dimshuffle_lift(fgraph, node):
@@ -434,66 +431,40 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):

"""
if len(node.outputs) > 1:
return
try:
shape_i = fgraph.shape_feature.shape_i
except AttributeError:
shape_i = None
if isinstance(node.op, Elemwise):
scalar_op = node.op.scalar_op
# print "aa", scalar_op.output_types_preference
if getattr(scalar_op, "output_types_preference", None) in (
ps.upgrade_to_float,
ps.upcast_out,
):
# this is the kind of op that we can screw with the input
# dtypes by upcasting explicitly
output_dtype = node.outputs[0].type.dtype
new_inputs = []
for i in node.inputs:
if i.type.dtype == output_dtype:
new_inputs.append(i)
else:
try:
cval_i = get_underlying_scalar_constant_value(
i, only_process_constants=True
)
if all(i.broadcastable):
new_inputs.append(
shape_padleft(cast(cval_i, output_dtype), i.ndim)
)
else:
if shape_i is None:
return
new_inputs.append(
alloc(
cast(cval_i, output_dtype),
*[shape_i(d)(i) for d in range(i.ndim)],
)
)
# print >> sys.stderr, "AAA",
# *[Shape_i(d)(i) for d in range(i.ndim)]
except NotScalarConstantError:
# for the case of a non-scalar
if isinstance(i, TensorConstant):
new_inputs.append(cast(i, output_dtype))
else:
new_inputs.append(i)
return None

if getattr(node.op.scalar_op, "output_types_preference", None) not in (
ps.upgrade_to_float,
ps.upcast_out,
):
return None

if new_inputs != node.inputs:
rval = [node.op(*new_inputs)]
if not node.outputs[0].type.is_super(rval[0].type):
# This can happen for example when floatX=float32
# and we do the true division between and int64
# and a constant that will get typed as int8.
# this is the kind of op that we can screw with the input
# dtypes by upcasting explicitly
[old_out] = node.outputs
output_dtype = old_out.type.dtype
new_inputs = list(node.inputs)
changed = False
for i, inp in enumerate(node.inputs):
if inp.type.dtype != output_dtype and isinstance(inp, TensorConstant):
new_inputs[i] = constant(inp.data.astype(output_dtype))
changed = True

if not changed:
return None

# As this is just to allow merging more case, if
# the upcast don't work, we can just skip it.
return
rval = node.op(*new_inputs)
if not old_out.type.is_super(rval.type):
# This can happen for example when floatX=float32
# and we do the true division between an int64
# and a constant that will get typed as int8.
# As this is just to allow merging more cases, if
# the upcast doesn't work, we can just skip it.
return None

# Copy over output stacktrace from before upcasting
copy_stack_trace(node.outputs[0], rval)
return rval
# Copy over output stacktrace from before upcasting
copy_stack_trace(old_out, rval)
return [rval]
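
The streamlined rewrite now only touches `TensorConstant` inputs whose dtype differs from the output dtype, casting their data directly. A minimal sketch of the situation it targets (hypothetical variable names):

```python
import pytensor.tensor as pt

x = pt.vector("x", dtype="float32")
c = pt.constant(2, dtype="int8")

# true_div(float32, int8) outputs float32; the rewrite replaces the int8
# constant with an equivalent float32 constant so that later rewrites
# (e.g. constant merging/folding) have fewer dtype mismatches to handle.
out = pt.true_div(x, c)
```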


@node_rewriter([add, mul])
4 changes: 2 additions & 2 deletions pytensor/tensor/rewriting/linalg.py
@@ -26,7 +26,7 @@
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod
from pytensor.tensor.math import Dot, Prod, _matmul, log, outer, prod
from pytensor.tensor.nlinalg import (
SVD,
KroneckerProduct,
@@ -284,7 +284,7 @@ def cholesky_ldotlt(fgraph, node):
# This rewrite only applies to matrix Dot
and A.owner.inputs[0].type.ndim == 2
)
or (A.owner.op == _matrix_matrix_matmul)
or (A.owner.op == _matmul)
)
):
return
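
Here only the matched Op changes: products built with `matmul` (the `_matmul` Blockwise) are now recognised alongside explicit `Dot` nodes when looking for the `cholesky(L @ L.T)` pattern. A hypothetical sketch of a graph the check now covers; whether the simplification actually fires still depends on the rewrite's remaining conditions on `L`:

```python
import pytensor.tensor as pt
from pytensor.tensor.slinalg import cholesky

L = pt.matrix("L", shape=(3, 3))
A = pt.matmul(L, L.mT)  # a _matmul (Blockwise) node rather than a Dot
chol = cholesky(A, lower=True)
# cholesky_ldotlt inspects A's owner; with this change a _matmul owner
# passes the same check as an explicit 2D Dot owner.
```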