Raise error if "initialize" method is called in eval mode. #1610

Open · wants to merge 2 commits into main

Conversation

@AustinT (Contributor) commented May 6, 2021

This PR is in response to #1556, wherein updating kernel hyperparameters while in eval mode leads to incorrect predictions, because the cached values in model.prediction_strategy are not updated. As suggested by @jacobrgardner, a simple way to avoid this is to throw an error if the parameters are updated in eval mode. I've implemented this by adding a small check in Module.initialize.
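Roughly, the check looks like this (a simplified sketch with illustrative names, not the exact diff in this PR):

```python
import torch


class Module(torch.nn.Module):
    """Simplified stand-in for gpytorch.Module, showing only the new check."""

    def initialize(self, **kwargs):
        # Refuse to overwrite hyperparameters in eval mode, since the cached
        # values in the model's prediction strategy would silently go stale.
        if not self.training:
            raise RuntimeError(
                "Attempting to set hyperparameters in eval mode. "
                "Call .train() first, set the parameters, then call .eval() again."
            )
        for name, value in kwargs.items():
            setattr(self, name, value)  # stand-in for the existing setter logic
        return self
```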

Pros of this implementation:

  • Affects all setter methods in one place (since they all seem to call self.initialize)
  • Simpler than writing separate code for each setter method

Cons of this implementation:

  • The check performed isn't exactly right: it only verifies that the kernel is in eval mode, whereas the real problem arises when the GP model is in eval mode, since the GP model is what holds the problematic cached values. Although uncommon, the .training attribute can differ between submodules (for example if the user sets them manually). It is unclear whether this can be avoided, since the kernel modules don't appear to hold a reference to the GP model that contains them.
  • Users can still bypass this if they set the raw_* tensors directly (although maybe this isn't a bad thing)

@jpchen (Contributor) commented Jun 3, 2021

FWIW I'd prefer to reset caches if parameters change/on initialize. We have a library wrapping a gpytorch model for fully bayesian inference, and often I load multiple sets of samples post-inference when prototyping. Not the biggest deal, but nicer than wrapping everything with train()/eval() blocks.
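For context, the wrapping I mean looks roughly like this (model, sample, and test_x are just placeholders):

```python
# Today's workaround: drop back into train mode before loading a new set of
# hyperparameter samples, then re-enter eval mode so caches are rebuilt.
model.train()
model.initialize(**{"covar_module.base_kernel.lengthscale": sample["lengthscale"]})
model.eval()
predictions = model(test_x)  # caches recomputed on this forward pass
```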

@AustinT (Contributor, Author) commented Jun 6, 2021

> FWIW I'd prefer to reset caches if parameters change/on initialize. We have a library wrapping a gpytorch model for fully bayesian inference, and often I load multiple sets of samples post-inference when prototyping. Not the biggest deal, but nicer than wrapping everything with train()/eval() blocks.

I agree that this might be a nicer solution. However, AFAIK the caches are all stored on the GP class rather than the kernel class, and the kernel classes don't appear to hold any reference to the GP model that contains them, so a kernel wouldn't be able to delete the model's cache unless the code structure were refactored.

@jacobrgardner (Member)

@jpchen @AustinT when you call initialize to load parameters for the kernel, are you typically calling initialize on the kernel (e.g., model.covar_module.initialize(lengthscale=...)) or on the model (e.g. model.initialize(**{'covar_module.lengthscale': ...}))?

In the latter case we could clear the caches by overriding initialize on ExactGP instead of the solution here, and then override it on Kernel to throw an error suggesting people initialize on the model if they want to do it in eval mode.
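Very roughly, something like this (a hypothetical sketch, not actual code):

```python
import gpytorch


class ExactGP(gpytorch.Module):
    def initialize(self, **kwargs):
        super().initialize(**kwargs)
        # Model-level initialize can safely drop the cached prediction strategy,
        # so updates made via model.initialize(...) stay consistent in eval mode.
        self.prediction_strategy = None
        return self


class Kernel(gpytorch.Module):
    def initialize(self, **kwargs):
        if not self.training:
            raise RuntimeError(
                "In eval mode, initialize hyperparameters through the model "
                "(model.initialize(...)) so that cached values can be cleared."
            )
        return super().initialize(**kwargs)
```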

@AustinT (Contributor, Author) commented Jun 7, 2021

> @jpchen @AustinT when you call initialize to load parameters for the kernel, are you typically calling initialize on the kernel (e.g., model.covar_module.initialize(lengthscale=...)) or on the model (e.g. model.initialize(**{'covar_module.lengthscale': ...}))?
>
> In the latter case we could clear the caches by overriding initialize on ExactGP instead of the solution here, and then override it on Kernel to throw an error suggesting people initialize on the model if they want to do it in eval mode.

Personally I've usually just used the setters (e.g. model.covar_module.outputscale = 5.0). I believe this ends up calling the kernel's initialize method, not the model's. I like your proposed solution, but it may still necessitate a lot of train/eval switching, which could be annoying to some people. I'm not sure what the best solution is here 🤷‍♂️
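For reference, the setter path looks roughly like this (simplified from the usual kernel pattern; exact method names may differ from what's in gpytorch):

```python
import torch
import gpytorch
from gpytorch.kernels import Kernel


class MyScaleKernel(Kernel):
    """Simplified version of the outputscale setter pattern."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.register_parameter("raw_outputscale", torch.nn.Parameter(torch.zeros(1)))
        self.register_constraint("raw_outputscale", gpytorch.constraints.Positive())

    @property
    def outputscale(self):
        return self.raw_outputscale_constraint.transform(self.raw_outputscale)

    @outputscale.setter
    def outputscale(self, value):
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_outputscale)
        # The assignment ends up calling the *kernel's* initialize, not the
        # model's, so a model-level override of initialize never sees it.
        self.initialize(
            raw_outputscale=self.raw_outputscale_constraint.inverse_transform(value)
        )

    def forward(self, x1, x2, **params):
        raise NotImplementedError  # covariance computation omitted from the sketch
```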

@gpleiss (Member) commented Jun 9, 2021

@AustinT @jacobrgardner @jpchen one possible (but maybe overly-complicated) solution would be to give each module a UUID that is updated every time some sort of state is changed (e.g. calling "initialize", going into training mode, etc.). We could then update the @cached decorators to check the UUIDs of the child modules, and delete the cache if any of them have changed.

Of course, this would be a major undertaking. Not sure if there are simpler solutions, but we could probably use an overhaul of our caching code anyways.
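Just to make the idea concrete, something along these lines (a hypothetical helper, not our actual @cached implementation; it assumes the module is a torch.nn.Module so .modules() exists, and caches only one value per object for brevity):

```python
import functools
import uuid


class StateTracked:
    """Mixin: bump the token on any state change (initialize, train/eval, ...)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._state_token = uuid.uuid4()

    def _bump_state_token(self):
        self._state_token = uuid.uuid4()


def cached(method):
    """Reuse the stored value only while the tokens of all tracked submodules match."""

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        tokens = tuple(
            m._state_token for m in self.modules() if isinstance(m, StateTracked)
        )
        stored = getattr(self, "_token_cache", None)
        if stored is not None and stored[0] == tokens:
            return stored[1]
        value = method(self, *args, **kwargs)
        self._token_cache = (tokens, value)
        return value

    return wrapper
```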

@jpchen (Contributor) commented Jun 9, 2021

> when you call initialize to load parameters for the kernel, are you typically calling initialize on the kernel ... or on the model?

I call initialize on the model, which then recursively calls initialize on the kernels. I was hoping there could be something simple, like checking for a mismatch between the cached and current parameter values when the model is called and updating the cache if so, but @jacobrgardner's workaround seems like a reasonable compromise between a larger cache rewrite and just throwing an error.

@AustinT (Contributor, Author) commented Jun 11, 2021

> @AustinT @jacobrgardner @jpchen one possible (but maybe overly-complicated) solution would be to give each module a UUID that is updated every time some sort of state is changed (e.g. calling "initialize", going into training mode, etc.). We could then update the @cached decorators to check the UUIDs of the child modules, and delete the cache if any of them have changed.
>
> Of course, this would be a major undertaking. Not sure if there are simpler solutions, but we could probably use an overhaul of our caching code anyways.

I like the motivation for this. Would a simpler solution be to give each child module a reference to its parent module (perhaps in the parent's setattr method)? Then, when a child is updated, it could recursively delete or invalidate the caches of its parent modules.
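Something like this, very roughly (attribute names are made up; the real cache storage may differ):

```python
import torch


class Module(torch.nn.Module):
    def __setattr__(self, name, value):
        if isinstance(value, Module):
            # Back-reference stored with object.__setattr__ so the parent is
            # not registered as a submodule (which would create a cycle).
            object.__setattr__(value, "_parent_module", self)
        super().__setattr__(name, value)

    def _clear_caches_upward(self):
        module = self
        while module is not None:
            module.__dict__.pop("_memoize_cache", None)  # assumed cache attribute name
            module = getattr(module, "_parent_module", None)

    def initialize(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)
        self._clear_caches_upward()  # updating a child also invalidates ancestors
        return self
```

In practice the back-reference should probably be a weakref so parent and child don't keep each other alive.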

@jacobrgardner (Member)

I think changing Module's setattr to give each module a parent reference would fix this in the majority of cases. The only edge case I can think of is if someone makes a gpytorch.Module a child of a torch.nn.Module, which I can't see any real use cases for -- most of the time, it's torch.nn.Module children of gpytorch.Module (e.g., feature extractors).

Then again, I don't think any solution we come up with will catch the case of, say, someone modifying the parameters of an NN feature extractor on a torch.nn.Module out from under a GP in eval mode, so maybe the setattr solution is the best we can do.

@AustinT (Contributor, Author) commented Jul 23, 2021

> I think changing Module's setattr to give each module a parent reference would fix this in the majority of cases. The only edge case I can think of is if someone makes a gpytorch.Module a child of a torch.nn.Module, which I can't see any real use cases for -- most of the time, it's torch.nn.Module children of gpytorch.Module (e.g., feature extractors).
>
> Then again, I don't think any solution we come up with will catch the case of, say, someone modifying the parameters of an NN feature extractor on a torch.nn.Module out from under a GP in eval mode, so maybe the setattr solution is the best we can do.

I never even thought about the case of changing the underlying parameters of an nn.Module attribute. Maybe the best solution is to warn users more explicitly not to change hyperparameters in eval mode once caches have been computed? This could be stated very clearly in the docs or in the example notebooks, for instance.

Another solution I can think of is a setting on the base GP class to always recompute cached values (i.e. disable caching). Users could then turn this on or off depending on what they are doing. Perhaps even having caches off by default might be a good choice for people who are playing around with hyperparameters in Jupyter notebooks...
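As a rough sketch of that kind of switch (purely hypothetical, no such setting exists today):

```python
import functools


class use_caching:
    """Hypothetical switch in the style of gpytorch.settings context managers:
    when disabled, decorated values are recomputed on every call."""

    _enabled = True

    def __init__(self, state):
        self._state = state

    def __enter__(self):
        self._previous = use_caching._enabled
        use_caching._enabled = self._state
        return self

    def __exit__(self, *exc_info):
        use_caching._enabled = self._previous
        return False


def cached(method):
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if not use_caching._enabled:
            return method(self, *args, **kwargs)  # bypass the cache entirely
        cache = self.__dict__.setdefault("_simple_cache", {})
        if method.__name__ not in cache:
            cache[method.__name__] = method(self, *args, **kwargs)
        return cache[method.__name__]

    return wrapper
```

Usage would then be something like `with use_caching(False): ...` while tweaking hyperparameters interactively.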

The setattr solution you proposed seems like a good idea though.
