[WIP] Allow function dispatch for constants #1159
base: main
Conversation
Force-pushed from e6e2bcf to 96ec531.
So did it work?
def pytorch_typify_no_conversion_needed(data, **kwargs):
    return data


@pytorch_typify.register(np.number)
def pytorch_typify_extract(data, **kwargs):
    return data.item()
This seems wrong?
Say more?
The torch compiler threw assertion errors when a zero-dimensional NumPy value was passed back.
Isn't it going to upcast values to float64 (or whatever Python integers are)? Does torch have scalars (not tensors) with specific dtypes we could use instead?
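For concreteness, a minimal illustration of the upcast concern, using only NumPy and plain Python:

import numpy as np

val = np.float32(1.5)
py_val = val.item()
print(type(py_val))             # <class 'float'> -- a double-precision Python float
print(type(np.int8(3).item()))  # <class 'int'>   -- an arbitrary-precision Python int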
It worked for the very simple case. I'm gonna try it with the 8 schools model and see if that works. It also broke a few tests lol
Interestingly enough, we have been considering defining constants as zero-input Ops to make code cleaner in rewrites (we always have to worry about whether a variable has an owner, which constants and root inputs do not).
That makes a lot of sense to me. I would be happy to sketch that out.
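Not part of this PR, but a minimal hedged sketch of what such a zero-input constant Op could look like, assuming PyTensor's standard Op/Apply API (the class name ConstantOp is hypothetical):

import numpy as np
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor import as_tensor_variable


class ConstantOp(Op):
    """A zero-input Op whose only job is to produce a fixed value."""

    def __init__(self, value):
        self.value = np.asarray(value)

    def make_node(self):
        # No inputs; the output variable's type mirrors the stored value.
        out = as_tensor_variable(self.value).type()
        return Apply(self, [], [out])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = self.value

Such a node has an owner (the ConstantOp apply), so rewrites would not need to special-case ownerless constants.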
x = pytorch_typify(x)


@torch.compiler.assume_constant_result
def torch_assume_constant(arg=x):
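For background, torch.compiler.assume_constant_result marks a function whose return value the compiler may treat as constant during tracing. A minimal standalone sketch, independent of this PR and assuming a recent PyTorch where this API is available:

import torch


@torch.compiler.assume_constant_result
def get_index():
    # torch.compile calls this once and bakes the result in as a constant
    return 1


@torch.compile
def f(x):
    return x[get_index()]


print(f(torch.arange(5)))  # tensor(1)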
This unfortunately suffers from the same problem some of the generated functions were hitting, where PyTorch loses their reference somewhere in the process. We'll need to make sure it gets added to the wrapper's gen_functors list.
assign_str = f"{new_output_name} = {getter_unique_name}()"
body_assigns.append(assign_str)
node_input_names.append(new_output_name)
continue
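For illustration, with hypothetical values for the two names, the rendered line would read:

# assuming new_output_name == "tensor_constant_0" and
# getter_unique_name == "torch_assume_constant_0" (illustrative names only)
tensor_constant_0 = torch_assume_constant_0()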
Hopefully we don't need this and it was an over-optimization. Try to refactor away the continue.
Description
PyTorch doesn't seem to handle the constants we put in the globals of the generated code well. It can cause graph breaks for things that shouldn't graph break (for example, the index in IncSubtensor when the index is static). This change allows the typify method to return a function instead, deferring the fetching of constants in the generated code to a function call.
A small example:
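A hedged stand-in example, built around the static-index IncSubtensor case mentioned in the description (the mode name "PYTORCH" and the exact graph are assumptions, not the author's original snippet):

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
# static index: indexing should not cause a graph break under torch.compile
out = pt.set_subtensor(x[1], 5.0)
fn = pytensor.function([x], out, mode="PYTORCH")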
will generate
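Roughly, a hedged sketch of the generated code, reconstructed from the diff hunks above; the outer function name and the constant's variable name are illustrative:

@torch.compiler.assume_constant_result
def torch_assume_constant(arg=x):
    # closes over the typified constant and just hands it back
    return arg


def pytorch_funcified_fgraph(a):
    # the constant is fetched through a call that torch.compile treats as
    # constant, instead of being read from a module-level global
    tensor_constant_0 = torch_assume_constant()
    ...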
torch_assume_constant can now do whatever it needs to do to allow the backend (torch in this case) to know that the variable is in fact a constant.
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1159.org.readthedocs.build/en/1159/