[WIP] Allow function dispatch for constants #1159
@@ -36,11 +36,15 @@ def pytorch_typify_tensor(data, dtype=None, **kwargs):

 @pytorch_typify.register(slice)
 @pytorch_typify.register(NoneType)
-@pytorch_typify.register(np.number)
 def pytorch_typify_no_conversion_needed(data, **kwargs):
     return data


+@pytorch_typify.register(np.number)
+def pytorch_typify_extract(data, **kwargs):
+    return data.item()
+
+
 @singledispatch
 def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
     """Create a PyTorch compatible function from an PyTensor `Op`."""
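For context, `pytorch_typify` is a `functools.singledispatch` registry, so the new `pytorch_typify_extract` overload only changes how NumPy scalar constants are converted: they now become plain Python scalars via `.item()`. Below is a minimal, self-contained sketch of the same dispatch pattern; the names `typify`, `_no_conversion_needed`, `_extract`, and the `torch.as_tensor` fallback are illustrative stand-ins, not PyTensor's actual module.

```python
from functools import singledispatch
from types import NoneType  # Python 3.10+; use type(None) on older versions

import numpy as np
import torch


@singledispatch
def typify(data, **kwargs):
    # Illustrative fallback: hand anything else to torch as a tensor.
    return torch.as_tensor(data)


@typify.register(slice)
@typify.register(NoneType)
def _no_conversion_needed(data, **kwargs):
    # slices and None pass through untouched
    return data


@typify.register(np.number)
def _extract(data, **kwargs):
    # NumPy scalars (np.float32(1.5), np.int8(3), ...) become Python scalars
    return data.item()


print(typify(np.float32(1.5)))  # 1.5, a plain Python float
print(typify(slice(0, 3)))      # slice(0, 3, None)
```

Because `singledispatch` walks the type's MRO, registering on the abstract `np.number` covers every concrete NumPy scalar type.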
@@ -57,11 +61,20 @@ def pytorch_funcify_FunctionGraph(
     conversion_func=pytorch_funcify,
     **kwargs,
 ):
+    def constants_wrapper(x, **kwargs):
+        x = pytorch_typify(x)
+
+        @torch.compiler.assume_constant_result
+        def torch_assume_constant(arg=x):
+            return arg
+
+        return torch_assume_constant
+
     built_kwargs = {"conversion_func": conversion_func, **kwargs}
     return fgraph_to_python(
         fgraph,
         conversion_func,
-        type_conversion_fn=pytorch_typify,
+        type_conversion_fn=constants_wrapper,
         fgraph_name=fgraph_name,
         **built_kwargs,
     )

Review comment on `def torch_assume_constant(arg=x):`: this unfortunately suffers from the same problem some of the generated functions were getting, where pytorch loses their reference somewhere in the process. We'll need to make sure it gets added to the …
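The key piece in this hunk is `torch.compiler.assume_constant_result`: each graph constant is typified once, closed over as a default argument, and wrapped in a zero-argument getter that `torch.compile` is told to treat as producing a constant rather than something to re-trace and guard on every call. A standalone sketch of that behaviour, assuming a reasonably recent torch (2.1+); `WEIGHT`, `get_weight`, and `scale` are made-up names for illustration:

```python
import torch

# Stand-in for a graph constant that lives outside the compiled function.
WEIGHT = torch.tensor([1.0, 2.0, 3.0])


@torch.compiler.assume_constant_result
def get_weight():
    # torch.compile assumes this result never changes, so it can be baked
    # into the compiled graph instead of being re-traced per call.
    return WEIGHT


@torch.compile
def scale(x):
    return x * get_weight()


print(scale(torch.ones(3)))  # tensor([1., 2., 3.])
```

`constants_wrapper` plays the role of `get_weight` here: the getter it returns is handed to `fgraph_to_python` through `type_conversion_fn`, which is what the second file's changes below have to support.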
@@ -749,9 +749,23 @@ def fgraph_to_python(
             )
             if input_storage[0] is not None or isinstance(i, Constant):
                 # Constants need to be assigned locally and referenced
-                global_env[local_input_name] = type_conversion_fn(
+                getter_or_value = type_conversion_fn(
                     input_storage[0], variable=i, storage=input_storage, **kwargs
                 )
+                if callable(getter_or_value):
+                    # we got passed a function, this could be used to indicate something
+                    # to the backend. We'll embed it
+                    new_output_name = unique_name(i)
+                    getter_unique_name = unique_name(getter_or_value)
+                    global_env[getter_unique_name] = getter_or_value
+                    assign_str = f"{new_output_name} = {getter_unique_name}()"
+                    body_assigns.append(assign_str)
+                    node_input_names.append(new_output_name)
+                    continue
+                else:
+                    global_env[local_input_name] = type_conversion_fn(
+                        input_storage[0], variable=i, storage=input_storage, **kwargs
+                    )
                 # TODO: We could attempt to use the storage arrays directly
                 # E.g. `local_input_name = f"{local_input_name}[0]"`
                 node_input_names.append(local_input_name)

Review comment on the `continue`: Hopefully we don't need this and it was an over-optimization. Try to refactor away the `continue`.
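To make the new branch concrete, and to show how the `continue` could be refactored away as suggested in the comment above, here is a hypothetical, stripped-down sketch of the embedding idea. `build_source`, the `constant_{idx}` naming, and the identity `type_conversion_fn` are invented for illustration; this is not PyTensor's actual `fgraph_to_python`.

```python
def build_source(constants, type_conversion_fn):
    """Generate source for a function whose constants may be getter callables."""
    global_env = {}
    body_assigns = []
    input_names = []

    for idx, value in enumerate(constants):
        converted = type_conversion_fn(value)
        name = f"constant_{idx}"
        if callable(converted):
            # Embed the getter in the generated function's globals and call it
            # inside the body, so the backend sees a function call rather than
            # a baked-in value.
            getter_name = f"{name}_getter"
            global_env[getter_name] = converted
            body_assigns.append(f"{name} = {getter_name}()")
        else:
            # Plain values are assigned directly into the globals.
            global_env[name] = converted
        # Both branches funnel into the same append, so no `continue` is needed.
        input_names.append(name)

    body = "\n    ".join(body_assigns) or "pass"
    src = f"def generated():\n    {body}\n    return ({', '.join(input_names)},)"
    return src, global_env


src, env = build_source([3, lambda: 7], lambda v: v)
exec(compile(src, "<generated>", "exec"), env)
print(env["generated"]())  # (3, 7)
```

The sketch simplifies by reusing one name per constant in both branches; the PR instead appends a fresh `unique_name(i)` in the callable case, which is what forces the `continue` past the shared `node_input_names.append(local_input_name)`.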
Review thread:

- "This seems wrong?"
- "Say more? The torch compiler threw asserts when a zero-dim np value was passed back."
- "Isn't it going to upcast values to float64, whatever Python integers are? Does torch have scalars (not tensors) with specific dtypes we can use instead?"
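On the upcasting question: `.item()` on a NumPy scalar returns a plain Python `int`/`float`, so the original width (`int8`, `float32`, ...) is gone and torch applies its defaults when the value re-enters the graph. As far as I know torch does not expose dtype-specific non-tensor scalars from Python; the closest equivalent is a zero-dim tensor, which does carry a dtype. A quick illustration of the difference, assuming a recent torch; whether a zero-dim tensor is acceptable to the compiler here is exactly the open question above:

```python
import numpy as np
import torch

c = np.int8(3)

# What the diff currently does: extract a Python scalar with .item()
py = c.item()
print(type(py))                   # <class 'int'>
print(torch.as_tensor(py).dtype)  # torch.int64 -- the original int8 width is lost

# Alternative: a zero-dim torch tensor keeps the dtype
t = torch.as_tensor(c)
print(t.dtype, t.ndim)            # torch.int8 0
```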