Description
When an `Einsum` can't be optimized (because the static shapes are unknown), it stays as an `OpFromGraph`. In that case we could replace it, through a `cxx_only` rewrite, with a `COp` that calls the NumPy C function `PyArray_EinsteinSum`:
https://numpy.org/devdocs/reference/c-api/array.html#c.PyArray_EinsteinSum
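As a rough illustration of what such a `COp` would compute at runtime, here is a minimal NumPy sketch. It assumes the einsum subscripts string is available when the node runs. `c_einsum_fallback` is a hypothetical name, not an existing PyTensor function. Calling `np.einsum` with `optimize=False` skips the contraction-path search and, as far as we know, dispatches directly to NumPy's C einsum routine, the same machinery exposed as `PyArray_EinsteinSum`. It therefore only needs the shapes of the actual inputs, not static shapes.

```python
import numpy as np


def c_einsum_fallback(subscripts: str, *operands: np.ndarray) -> np.ndarray:
    """Hypothetical runtime fallback for a non-optimized Einsum.

    With optimize=False, np.einsum does no contraction-path search and
    evaluates the expression directly from the operands' runtime shapes.
    """
    return np.einsum(subscripts, *operands, optimize=False)


# The shapes are only known at call time, not when the graph is built.
rng = np.random.default_rng(0)
for n in (2, 5):
    a = rng.normal(size=(n, 3))
    b = rng.normal(size=(3, 4))
    out = c_einsum_fallback("ij,jk->ik", a, b)
    print(n, out.shape)
```

A real `COp` would make the equivalent call from generated C code instead of going through Python.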
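The existing rewrite in pytensor/pytensor/tensor/rewriting/einsum.py (lines 39 to 53 in f25a624) only inlines einsums that are already optimized, so a non-optimized one is left as an `OpFromGraph`: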
```python
@register_specialize
@node_rewriter([Einsum])
def inline_optimized_einsum(
    fgraph: FunctionGraph, node: Apply
) -> list[TensorVariable] | None:
    """Inline einsums that are already optimized.

    This allows the inner graph to be optimized with the rest of the graph, now that we got ordering right.
    """
    op: Einsum = node.op
    if not op.optimized:
        return None
    return cast(list[TensorVariable], inline_ofg_node(node))
```