Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added dot kron rewrite #1090

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,3 +989,20 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
"jax",
position=0.9, # Run before canonicalization
)


@register_canonicalize
@register_stabilize
@node_rewriter([Dot])
def rewrite_dot_kron(fgraph, node):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs a docstring and typehints

potential_kron = node.inputs[0].owner
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
return False

c = node.inputs[1]
[a, b] = potential_kron.inputs

m, n = a.type.shape
p, q = b.type.shape
out_clever = pt.expand_dims((b @ c.reshape(shape=(n, q)).T @ a.T).T.ravel(), 1)
tanish1729 marked this conversation as resolved.
Show resolved Hide resolved
return [out_clever]
26 changes: 26 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,3 +906,29 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
assert any(isinstance(node.op, Cholesky) for node in nodes)


def test_dot_kron_rewrite():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add test with batch dims. I am worried the use of e.g. ravel in the rewrite will cause problems.

m, n, p, q = 3, 4, 6, 7
a = pt.matrix("a", shape=(m, n))
b = pt.matrix("b", shape=(p, q))
c = pt.matrix("c", shape=(n * q, 1))
out_direct = pt.linalg.kron(a, b) @ c

# REWRITE TEST
f_direct_rewritten = function([a, b, c], out_direct, mode="FAST_RUN")
nodes = f_direct_rewritten.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, KroneckerProduct) for node in nodes)

# NUMERIC VALUE TEST
a_test = np.random.rand(m, n).astype(config.floatX)
b_test = np.random.rand(p, q).astype(config.floatX)
c_test = np.random.rand(n * q, 1).astype(config.floatX)
out_direct_val = np.kron(a_test, b_test) @ c_test
out_clever_val = f_direct_rewritten(a_test, b_test, c_test)
tanish1729 marked this conversation as resolved.
Show resolved Hide resolved
assert_allclose(
out_direct_val,
out_clever_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)
Loading