Recently, torch.compile() started using FakeTensors for both inputs and weights during compilation. This means that temporary FakeTensor weights are created from the original Tensor weights, but the infshape attributes are not copied to these FakeTensor weights.
Consequently, during compilation, MuReadout.forward() and MuReadout.width_mult() trip this assert and the compilation fails.
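Here is a minimal sketch of what I believe is happening, outside of mup. It assumes that `FakeTensorMode.from_tensor` (from the private `torch._subclasses.fake_tensor` module) converts tensors in the same way torch.compile() does internally. The `infshape` value is a placeholder string, not a real mup `InfShape` object.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Real weight with a custom Python attribute attached,
# similar to how mup attaches `infshape` to parameters.
weight = torch.nn.Parameter(torch.zeros(4, 8))
weight.infshape = "placeholder"  # hypothetical stand-in for a mup InfShape

# Convert the real weight to a FakeTensor.
# Private API; assumed here to mirror the conversion done by torch.compile().
fake_mode = FakeTensorMode()
fake_weight = fake_mode.from_tensor(weight)

# Shape metadata is carried over; the check below shows whether
# the custom attribute survived the conversion.
print("fake shape:", tuple(fake_weight.shape))
print("real has infshape:", hasattr(weight, "infshape"))
print("fake has infshape:", hasattr(fake_weight, "infshape"))
```

If the last line prints `False` on your PyTorch version, any code that asserts on `weight.infshape` during tracing would fail in the same way as described above.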
This unwanted side effect will also affect the ability to, e.g., export mup models to ONNX.
Any advice on how to work around the missing infshapes on FakeTensors going forward?
Is mup compatible with torch.compile() in PyTorch 2? If yes, what is the correct usage (e.g. should we apply mup before or after compile)?