From 90880e248fe9294e44c27f33b883fc0608d6714d Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Mon, 20 Jan 2025 07:50:10 -0800 Subject: [PATCH] Allow function dispatch for constants --- pytensor/link/pytorch/dispatch/basic.py | 15 +++++++++++++-- pytensor/link/pytorch/linker.py | 1 + pytensor/link/utils.py | 18 +++++++++++++++++- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 11e1d6c63a..5ec5a366d6 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -36,10 +36,12 @@ def pytorch_typify_tensor(data, dtype=None, **kwargs): @pytorch_typify.register(slice) @pytorch_typify.register(NoneType) -@pytorch_typify.register(np.number) def pytorch_typify_no_conversion_needed(data, **kwargs): return data +@pytorch_typify.register(np.number) +def pytorch_typify_extract(data, **kwargs): + return data.item() @singledispatch def pytorch_funcify(op, node=None, storage_map=None, **kwargs): @@ -57,11 +59,20 @@ def pytorch_funcify_FunctionGraph( conversion_func=pytorch_funcify, **kwargs, ): + def constants_wrapper(x, **kwargs): + x = pytorch_typify(x) + + @torch.compiler.assume_constant_result + def torch_assume_constant(arg=x): + return arg + + return torch_assume_constant + built_kwargs = {"conversion_func": conversion_func, **kwargs} return fgraph_to_python( fgraph, conversion_func, - type_conversion_fn=pytorch_typify, + type_conversion_fn=constants_wrapper, fgraph_name=fgraph_name, **built_kwargs, ) diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index d47aa43dda..4a5acd5b85 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -51,6 +51,7 @@ class wrapper: """ def __init__(self, fn, gen_functors): + self._fn = fn self.fn = torch.compile(fn) self.gen_functors = gen_functors.copy() diff --git a/pytensor/link/utils.py b/pytensor/link/utils.py index 9cbc3838dd..d02398f85b 100644 --- a/pytensor/link/utils.py +++ b/pytensor/link/utils.py @@ -749,9 +749,25 @@ def fgraph_to_python( ) if input_storage[0] is not None or isinstance(i, Constant): # Constants need to be assigned locally and referenced - global_env[local_input_name] = type_conversion_fn( + getter_or_value = type_conversion_fn( input_storage[0], variable=i, storage=input_storage, **kwargs ) + if callable(getter_or_value): + # we got passed a function, this could be used to indicate something + # to the backend. We'll embed it + new_output_name = unique_name(i) + getter_unique_name = unique_name(getter_or_value) + global_env[getter_unique_name] = getter_or_value + assign_str = ( + f"{new_output_name} = {getter_unique_name}()" + ) + body_assigns.append(assign_str) + node_input_names.append(new_output_name) + continue + else: + global_env[local_input_name] = type_conversion_fn( + input_storage[0], variable=i, storage=input_storage, **kwargs + ) # TODO: We could attempt to use the storage arrays directly # E.g. `local_input_name = f"{local_input_name}[0]"` node_input_names.append(local_input_name)