Skip to content

Commit

Permalink
Allow function dispatch for constants
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Schweer committed Jan 21, 2025
1 parent d4a2b2b commit 90880e2
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 3 deletions.
15 changes: 13 additions & 2 deletions pytensor/link/pytorch/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@ def pytorch_typify_tensor(data, dtype=None, **kwargs):

@pytorch_typify.register(slice)
@pytorch_typify.register(NoneType)
@pytorch_typify.register(np.number)
def pytorch_typify_no_conversion_needed(data, **kwargs):
    # Identity typify: slices, None, and NumPy scalars are passed through
    # to the PyTorch backend unchanged (no tensor conversion).
    return data

@pytorch_typify.register(np.number)
def pytorch_typify_extract(data, **kwargs):
    # Unwrap a NumPy scalar (np.number) into the equivalent plain Python
    # scalar via ``.item()``.
    # NOTE(review): np.number is also registered on
    # pytorch_typify_no_conversion_needed above; with singledispatch the
    # later registration wins. This looks like flattened-diff residue
    # (the old registration was likely removed) — confirm against the
    # real file.
    return data.item()

@singledispatch
def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
Expand All @@ -57,11 +59,20 @@ def pytorch_funcify_FunctionGraph(
conversion_func=pytorch_funcify,
**kwargs,
):
def constants_wrapper(x, **kwargs):
    """Typify a constant and wrap it for function dispatch.

    The typified value is captured as a default argument of a
    zero-argument closure decorated with
    ``torch.compiler.assume_constant_result``, telling ``torch.compile``
    it may treat the returned value as a fixed constant instead of
    guarding on it.
    """
    constant_value = pytorch_typify(x)

    @torch.compiler.assume_constant_result
    def torch_assume_constant(arg=constant_value):
        return arg

    return torch_assume_constant

built_kwargs = {"conversion_func": conversion_func, **kwargs}
return fgraph_to_python(
fgraph,
conversion_func,
type_conversion_fn=pytorch_typify,
type_conversion_fn=constants_wrapper,
fgraph_name=fgraph_name,
**built_kwargs,
)
Expand Down
1 change: 1 addition & 0 deletions pytensor/link/pytorch/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class wrapper:
"""

def __init__(self, fn, gen_functors):
    # Keep a handle on the original, uncompiled callable; ``self.fn``
    # below is the torch.compile-wrapped version used for execution.
    self._fn = fn
    self.fn = torch.compile(fn)
    # Copy so later mutation of the caller's list cannot affect us.
    self.gen_functors = gen_functors.copy()

Expand Down
18 changes: 17 additions & 1 deletion pytensor/link/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,9 +749,25 @@ def fgraph_to_python(
)
if input_storage[0] is not None or isinstance(i, Constant):
# Constants need to be assigned locally and referenced
global_env[local_input_name] = type_conversion_fn(
getter_or_value = type_conversion_fn(
input_storage[0], variable=i, storage=input_storage, **kwargs
)
if callable(getter_or_value):
# we got passed a function, this could be used to indicate something
# to the backend. We'll embed it
new_output_name = unique_name(i)
getter_unique_name = unique_name(getter_or_value)
global_env[getter_unique_name] = getter_or_value
assign_str = (
f"{new_output_name} = {getter_unique_name}()"
)
body_assigns.append(assign_str)
node_input_names.append(new_output_name)
continue
else:
global_env[local_input_name] = type_conversion_fn(
input_storage[0], variable=i, storage=input_storage, **kwargs
)
# TODO: We could attempt to use the storage arrays directly
# E.g. `local_input_name = f"{local_input_name}[0]"`
node_input_names.append(local_input_name)
Expand Down

0 comments on commit 90880e2

Please sign in to comment.