[WIP] Allow function dispatch for constants #1159

Open · wants to merge 2 commits into main from linker_scalar_handling
Conversation

@Ch0ronomato (Contributor) commented on Jan 21, 2025

Description

PyTorch doesn't seem to handle the constants we put in the globals of the generated code well: they can cause graph breaks for things that shouldn't cause one (for example, the index in IncSubtensor when the index is static). This change allows the typify method to instead return a function, and defers fetching the constants in the generated code to a function call.

A small example:

import pytensor
import pytensor.tensor as ptt

x = ptt.vector('x')
o = x[-1].inc(1)
f = pytensor.function(inputs=[x], outputs=o, mode="PYTORCH")

will generate

def pytorch_funcified_fgraph(x):
    tensor_constant = torch_assume_constant()
    scalar_constant = torch_assume_constant_1()
    # IncSubtensor{i}(x, 1, -1)
    tensor_variable = inc_subtensor(x, tensor_constant, scalar_constant)
    return (tensor_variable,)

torch_assume_constant can now do whatever it needs to do to let the backend (torch in this case) know that the variable is in fact a constant.
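
For illustration, a minimal sketch of what such a dispatch could return, built around torch.compiler.assume_constant_result (the helper name pytorch_typify_constant_as_getter is hypothetical; only the decorator usage is taken from the diff reviewed further down):

import numpy as np
import torch


def pytorch_typify_constant_as_getter(data, **kwargs):
    """Sketch: return a zero-argument getter instead of the raw constant."""
    value = torch.as_tensor(np.asarray(data))

    # assume_constant_result tells torch.compile that the return value never
    # changes, so the constant can be baked into the traced graph instead of
    # triggering a graph break.
    @torch.compiler.assume_constant_result
    def torch_assume_constant():
        return value

    return torch_assume_constant

The generated pytorch_funcified_fgraph above then materializes the constant with a plain torch_assume_constant() call.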

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1159.org.readthedocs.build/en/1159/

@Ch0ronomato Ch0ronomato force-pushed the linker_scalar_handling branch from e6e2bcf to 96ec531 Compare January 21, 2025 03:22
@Ch0ronomato Ch0ronomato changed the title Allow function dispatch for constants [WIP] Allow function dispatch for constants Jan 21, 2025
@ricardoV94 (Member) left a comment:


So did it work?

def pytorch_typify_no_conversion_needed(data, **kwargs):
    return data


@pytorch_typify.register(np.number)
def pytorch_typify_extract(data, **kwargs):
    return data.item()
@ricardoV94 (Member):
This seems wrong?

@Ch0ronomato (Contributor, Author):
Say more?

The torch compiler threw asserts when a zero-dim NumPy value was passed back.

@ricardoV94 (Member):
Isn't it going to upcast values to float64 (or whatever Python integers are)? Does torch have scalars (not tensors) with specific dtypes we can use instead?
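
For context, a minimal illustration of the dtype concern (not from the PR): .item() on a zero-dim NumPy value returns a plain Python scalar and discards the dtype, whereas a 0-dim torch tensor keeps it:

import numpy as np
import torch

x = np.float32(1.5)
print(type(x.item()))            # <class 'float'> -- the float32 dtype is gone
print(torch.as_tensor(x).dtype)  # torch.float32 -- a 0-dim tensor preserves it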

@Ch0ronomato (Contributor, Author):
It worked for the very simple case. I'm gonna try it with the 8 schools model and see if that works.

It also broke a few tests lol

@ricardoV94 (Member):
Interestingly enough, we have been considering defining constants as zero-input Ops to make the rewrite code cleaner (we always have to worry about whether a variable has an owner, which constants and root inputs do not).

@Ch0ronomato (Contributor, Author):
> Interestingly enough, we have been considering defining constants as zero-input Ops to make the rewrite code cleaner (we always have to worry about whether a variable has an owner, which constants and root inputs do not).

That makes a lot of sense to me. I would be happy to sketch that out.
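
A very rough sketch of what a zero-input constant Op could look like (hypothetical and untested; the name ConstantOp and all details are made up for illustration, not part of this PR):

import numpy as np

from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor.type import TensorType


class ConstantOp(Op):
    """Hypothetical zero-input Op that wraps a fixed value."""

    def __init__(self, value):
        self.value = np.asarray(value)

    def make_node(self):
        # No inputs: the wrapped value fully determines the single output.
        out_type = TensorType(dtype=self.value.dtype, shape=self.value.shape)
        return Apply(self, [], [out_type()])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = self.value

Rewrites would then only ever deal with variables that have an owner.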

x = pytorch_typify(x)

@torch.compiler.assume_constant_result
def torch_assume_constant(arg=x):
@Ch0ronomato (Contributor, Author):
This unfortunately suffers from the same problem some of the generated functions were hitting, where PyTorch loses the reference somewhere in the process. We'll need to make sure it gets added to the wrapper's gen_functors list.

assign_str = f"{new_output_name} = {getter_unique_name}()"
body_assigns.append(assign_str)
node_input_names.append(new_output_name)
continue
@Ch0ronomato (Contributor, Author):
Hopefully we don't need this and it was an over-optimization. Try to refactor away the continue.


Successfully merging this pull request may close these issues.

IncSubtensor causes graph break in pytorch backend
2 participants