diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py
index 11e1d6c63a..99977a5915 100644
--- a/pytensor/link/pytorch/dispatch/basic.py
+++ b/pytensor/link/pytorch/dispatch/basic.py
@@ -36,11 +36,15 @@ def pytorch_typify_tensor(data, dtype=None, **kwargs):
 
 @pytorch_typify.register(slice)
 @pytorch_typify.register(NoneType)
-@pytorch_typify.register(np.number)
 def pytorch_typify_no_conversion_needed(data, **kwargs):
     return data
 
 
+@pytorch_typify.register(np.number)
+def pytorch_typify_extract(data, **kwargs):
+    return data.item()
+
+
 @singledispatch
 def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
     """Create a PyTorch compatible function from an PyTensor `Op`."""
@@ -57,11 +61,13 @@ def pytorch_funcify_FunctionGraph(
     conversion_func=pytorch_funcify,
     **kwargs,
 ):
+    if "type_conversion_fn" not in kwargs:
+        kwargs["type_conversion_fn"] = pytorch_typify
+
     built_kwargs = {"conversion_func": conversion_func, **kwargs}
     return fgraph_to_python(
         fgraph,
         conversion_func,
-        type_conversion_fn=pytorch_typify,
         fgraph_name=fgraph_name,
         **built_kwargs,
     )
diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py
index d47aa43dda..8a6fc8b6f5 100644
--- a/pytensor/link/pytorch/linker.py
+++ b/pytensor/link/pytorch/linker.py
@@ -10,7 +10,9 @@ def __init__(self, *args, **kwargs):
         self.gen_functors = []
 
     def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
-        from pytensor.link.pytorch.dispatch import pytorch_funcify
+        import torch
+
+        from pytensor.link.pytorch.dispatch import pytorch_funcify, pytorch_typify
 
         # We want to have globally unique names
         # across the entire pytensor graph, not
@@ -25,9 +27,21 @@ def conversion_func_register(*args, **kwargs):
             self.gen_functors.append((f"_{name}", functor))
             return functor
 
+        def constants_wrapper(x, **kwargs):
+            x = pytorch_typify(x)
+
+            @torch.compiler.assume_constant_result
+            def torch_assume_constant(arg=x):
+                return arg
+
+            name = kwargs["unique_name"](torch_assume_constant)
+            self.gen_functors.append((f"_{name}", torch_assume_constant))
+            return torch_assume_constant
+
         built_kwargs = {
             "unique_name": generator,
             "conversion_func": conversion_func_register,
+            "type_conversion_fn": constants_wrapper,
             **kwargs,
         }
         return pytorch_funcify(
diff --git a/pytensor/link/utils.py b/pytensor/link/utils.py
index 9cbc3838dd..142fefc04d 100644
--- a/pytensor/link/utils.py
+++ b/pytensor/link/utils.py
@@ -749,11 +749,25 @@ def fgraph_to_python(
             )
             if input_storage[0] is not None or isinstance(i, Constant):
                 # Constants need to be assigned locally and referenced
-                global_env[local_input_name] = type_conversion_fn(
+                getter_or_value = type_conversion_fn(
                     input_storage[0], variable=i, storage=input_storage, **kwargs
                 )
-                # TODO: We could attempt to use the storage arrays directly
-                # E.g. `local_input_name = f"{local_input_name}[0]"`
+                if callable(getter_or_value):
+                    # we got passed a function, this could be used to indicate something
+                    # to the backend. We'll embed it
+                    new_output_name = unique_name(i)
+                    getter_unique_name = unique_name(getter_or_value)
+                    global_env[getter_unique_name] = getter_or_value
+                    assign_str = f"{new_output_name} = {getter_unique_name}()"
+                    body_assigns.append(assign_str)
+                    node_input_names.append(new_output_name)
+                    continue
+                else:
+                    global_env[local_input_name] = type_conversion_fn(
+                        input_storage[0], variable=i, storage=input_storage, **kwargs
+                    )
+                    # TODO: We could attempt to use the storage arrays directly
+                    # E.g. `local_input_name = f"{local_input_name}[0]"`
             node_input_names.append(local_input_name)
 
         node_output_names = [unique_name(v) for v in node.outputs]
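Note (not part of the patch): a minimal sketch of the PyTorch mechanism the linker change relies on. `torch.compiler.assume_constant_result` marks a function whose return value `torch.compile` may treat as a compile-time constant, which is how graph constants get embedded in the generated code above. The tensor value and function names below are hypothetical.

    import torch

    constant = torch.tensor([1.0, 2.0, 3.0])  # hypothetical graph constant

    @torch.compiler.assume_constant_result
    def get_constant(arg=constant):
        # Dynamo may evaluate this at trace time and treat the result as constant,
        # instead of tracing it as a graph input.
        return arg

    @torch.compile
    def fn(y):
        return y + get_constant()

    fn(torch.ones(3))  # tensor([2., 3., 4.])

Separately, `pytorch_typify` now unwraps NumPy scalars (`np.number`) into Python scalars via `.item()`, e.g. `np.float64(1.5).item() == 1.5`, rather than passing them through unconverted.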