At this stage `self.data_splitter.setup()` has not been called yet, and PyTorch Lightning expects to call it later itself. So we need to copy `self.data_splitter`, call `setup()` on the copy, and create the dataloader this callback needs.
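A minimal sketch of that workaround (the helper name is hypothetical, and it assumes an scvi-tools-style data splitter exposing `setup()` and `train_dataloader()`):

```python
import copy


def make_warmup_dataloader(data_splitter):
    """Build the dataloader the callback needs without touching the
    original splitter (hypothetical helper; assumes an scvi-style
    DataSplitter with setup() and train_dataloader())."""
    # Deep-copy first: Lightning still expects to call setup() on the
    # original self.data_splitter later, so we must not mutate it.
    splitter = copy.deepcopy(data_splitter)
    splitter.setup()
    return splitter.train_dataloader()
```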
Using `model.train()` is probably better, because the model would then have all of its parameters created before Lightning ever sees it.
Hey, sorry: I took a look at this and forgot to respond. I think it makes sense to add the fixes to `train()` rather than `TrainRunner`, since this will be specific to Pyro models. Happy to take a PR if you'd like to take a stab at it!
Sounds good! Later this week I will open a PR for this issue, as well as for another issue with the second GuideWarmup callback (Pyro doesn't track deterministic variables initialised after setup).
I think we need to get rid of both Pyro GuideWarmup callbacks and just run the guide once in `model.train()`. This would break how people use them now, but IMO it is the better solution.
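For illustration, a hedged sketch of what running the guide once inside `train()` could look like (not the actual scvi-tools implementation; it assumes an scvi-style Pyro module exposing `guide` and `_get_fn_args_from_batch`):

```python
import torch


def warm_up_guide(pyro_module, dataloader):
    """Run the guide on a single batch so all guide parameters exist
    before the module is handed to Lightning (sketch; pyro_module is
    assumed to follow scvi-tools' PyroBaseModuleClass interface)."""
    tensors = next(iter(dataloader))
    args, kwargs = pyro_module._get_fn_args_from_batch(tensors)
    with torch.no_grad():
        # One guide call is enough: Pyro instantiates parameters
        # lazily on first use.
        pyro_module.guide(*args, **kwargs)
```

Called at the top of `model.train()`, something like this would make both GuideWarmup callbacks unnecessary.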
`PyroModelGuideWarmup` fails on GPU, probably because `Callback.setup()` is called in the accelerator environment in the latest PyTorch Lightning. This test fails on GPU:
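(The original failing test is not reproduced here. As an illustration only, the following self-contained sketch, entirely assumed and using plain PyTorch Lightning rather than scvi-tools, probes which device the module is on when `Callback.setup()` fires:)

```python
import torch
from pytorch_lightning import Callback, LightningModule, Trainer
from torch.utils.data import DataLoader, TensorDataset


class DeviceProbe(Callback):
    """Print where the module's parameters live when setup() fires."""

    def setup(self, trainer, pl_module, stage=None):
        print("Callback.setup() runs with params on:",
              next(pl_module.parameters()).device)


class TinyModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-2)


if __name__ == "__main__":
    loader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)
    trainer = Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        max_epochs=1,
        callbacks=[DeviceProbe()],
    )
    trainer.fit(TinyModule(), loader)
```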
Versions: