[WIP] Allow function dispatch for constants #1159

Open · wants to merge 3 commits into base: main · changes from 2 commits
17 changes: 15 additions & 2 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -36,11 +36,15 @@ def pytorch_typify_tensor(data, dtype=None, **kwargs):

 @pytorch_typify.register(slice)
 @pytorch_typify.register(NoneType)
-@pytorch_typify.register(np.number)
 def pytorch_typify_no_conversion_needed(data, **kwargs):
     return data


+@pytorch_typify.register(np.number)
+def pytorch_typify_extract(data, **kwargs):
+    return data.item()
Member: This seems wrong?

Contributor Author: Say more? The torch compiler threw asserts when a zero-dim NumPy value was passed back.

Member: Isn't it going to upcast values to float64, or whatever Python integers are? Does torch have scalars (not tensors) with specific dtypes we can use instead?
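A minimal sketch of the upcasting concern being raised here, assuming only NumPy and PyTorch; the variable names are illustrative, not part of this PR:

```python
import numpy as np
import torch

x = np.float32(1.5)
y = x.item()  # .item() returns a plain Python float; the float32 dtype is lost

print(type(y))                   # <class 'float'>
print(torch.as_tensor(x).dtype)  # torch.float32 -- dtype preserved
print(np.asarray(y).dtype)       # float64 -- re-wrapping the bare scalar upcasts
```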



+
+
 @singledispatch
 def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
     """Create a PyTorch compatible function from an PyTensor `Op`."""
@@ -57,11 +61,20 @@ def pytorch_funcify_FunctionGraph(
     conversion_func=pytorch_funcify,
     **kwargs,
 ):
+    def constants_wrapper(x, **kwargs):
+        x = pytorch_typify(x)
+
+        @torch.compiler.assume_constant_result
+        def torch_assume_constant(arg=x):
Contributor Author: This unfortunately suffers from the same problem some of the generated functions were hitting, where PyTorch loses their reference somewhere in the process. We'll need to make sure it gets added to the wrapper::gen_functors list.

+            return arg
+
+        return torch_assume_constant
+
     built_kwargs = {"conversion_func": conversion_func, **kwargs}
     return fgraph_to_python(
         fgraph,
         conversion_func,
-        type_conversion_fn=pytorch_typify,
+        type_conversion_fn=constants_wrapper,
         fgraph_name=fgraph_name,
         **built_kwargs,
     )
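For context on the mechanism used in constants_wrapper above: torch.compiler.assume_constant_result marks a function whose return value torch.compile may treat as a compile-time constant. A minimal sketch under that assumption (get_constant and f are hypothetical names, not part of this PR):

```python
import torch

@torch.compiler.assume_constant_result
def get_constant():
    # torch.compile may fold this call's result into the compiled graph
    # instead of tracing through it or guarding on it
    return torch.tensor([1.0, 2.0])

@torch.compile
def f(x):
    return x + get_constant()

print(f(torch.zeros(2)))  # tensor([1., 2.])
```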
16 changes: 15 additions & 1 deletion pytensor/link/utils.py
@@ -749,9 +749,23 @@ def fgraph_to_python(
             )
             if input_storage[0] is not None or isinstance(i, Constant):
                 # Constants need to be assigned locally and referenced
-                global_env[local_input_name] = type_conversion_fn(
+                getter_or_value = type_conversion_fn(
                     input_storage[0], variable=i, storage=input_storage, **kwargs
                 )
+                if callable(getter_or_value):
+                    # we got passed a function, this could be used to indicate something
+                    # to the backend. We'll embed it
+                    new_output_name = unique_name(i)
+                    getter_unique_name = unique_name(getter_or_value)
+                    global_env[getter_unique_name] = getter_or_value
+                    assign_str = f"{new_output_name} = {getter_unique_name}()"
+                    body_assigns.append(assign_str)
+                    node_input_names.append(new_output_name)
+                    continue
Contributor Author: Hopefully we don't need this and it was an over-optimization. I'll try to refactor away the continue.

+                else:
+                    global_env[local_input_name] = type_conversion_fn(
+                        input_storage[0], variable=i, storage=input_storage, **kwargs
+                    )
                 # TODO: We could attempt to use the storage arrays directly
                 # E.g. `local_input_name = f"{local_input_name}[0]"`
                 node_input_names.append(local_input_name)
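To make the generated-code path above concrete, here is a simplified, self-contained sketch of the embedding mechanism: the callable is stored in the generated module's globals, and the emitted source calls it to produce the constant's value. The names make_getter, getter_0, and c_0 are hypothetical, not pytensor API:

```python
def make_getter(value):
    # stand-in for a type_conversion_fn that returns a callable
    def getter():
        return value
    return getter

global_env = {}
body_assigns = []

global_env["getter_0"] = make_getter(42)  # embed the callable in the globals
body_assigns.append("c_0 = getter_0()")   # the generated body invokes it

src = "\n".join(body_assigns) + "\nprint(c_0)"
exec(compile(src, "<fgraph>", "exec"), global_env)  # prints 42
```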