-
Notifications
You must be signed in to change notification settings - Fork 8
Add unit test for existing ac API behavior #244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: gh/ailzhang/3/base
Are you sure you want to change the base?
Conversation
|
The first PR is just adding a unit test (with claude :P ) to make sure I understand the current behavior 🙏 Any feedback is appreciated! |
| return Transformer(model_args) | ||
|
|
||
|
|
||
| def create_joint_graph_from_model(model, input_args): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this implementation is strange, any issues just using the same joint graph capture frontend as the rest of the repo?
autoparallel/autoparallel/api.py
Lines 301 to 310 in b1c4909
| torch_ir_with_fqn = _export(self.model, inputs) | |
| # TODO Cna't use fake mode here because it clashes with the user level | |
| # fake mode. Ideally dynamo should reuse the user level fake mode. | |
| self.joint_with_descriptors = aot_export_joint_with_descriptors( | |
| self.stack, | |
| torch_ir_with_fqn, | |
| inputs, | |
| decompositions=decomp_table, | |
| ) | |
| gm = self.joint_with_descriptors.graph_module |
| # Define save list with operations that might be in the graph | ||
| save_list = { | ||
| torch.ops.aten.mm.default, | ||
| torch.ops.aten.addmm.default, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think this always gets decomposed away. and if it doesn't, it will mess up your saved node count lol
Stack from ghstack (oldest at bottom):