[RFC] Move input transforms to GPyTorch #1372
Conversation
Summary: This diff presents a minimal implementation of input transforms in GPyTorch. See cornellius-gp/gpytorch#2114 for the GPyTorch side of these changes.

What this does:
* Moves `transform_inputs` from the BoTorch `Model` to the GPyTorch `GP` class, with some modifications to explicitly identify whether given inputs are train or test inputs.
* Modifies the `InputTransform.forward` call to use an `is_training_input` argument instead of the `self.training` check to apply the transforms that have `transform_on_train=True`.
* Removes the `preprocess_transform` method, since it is no longer needed.
* For `ExactGP` models, transforms both train and test inputs in `__call__`. For `train_inputs` it always uses `is_training_input=True`. For generic `inputs`, it uses `is_training_input=self.training`, which signals that these are training inputs when the model is in `train` mode and test inputs when the model is in `eval` mode. (See the sketch after this summary.)
* For `ApproximateGP` models, applies the transform to `inputs` in `__call__` using `is_training_input=self.training`. This again signifies whether the given inputs are train or test inputs based on the mode of the model. Note that this NEVER transforms `inducing_points`, thus fixing the previous bug where `inducing_points` got transformed in `train` but not in `eval`. It is expected that the user will define inducing points in the appropriate space (mostly the normalized space / unit cube).
* For the BoTorch `SingleTaskVariationalGP`, moves the `input_transform` attribute down to `_SingleTaskVariationalGP`, which is the actual `ApproximateGP` instance. This makes the transform accessible from GPyTorch.

What this doesn't do:
* It doesn't do anything about `DeterministicModel`s. Those will still need to deal with their own transforms, which is not implemented here. If we make `Model` inherit from `GP`, we can keep the existing setup with very minimal changes.
* It does not clean up the call sites of `self.transform_inputs`. This is just made into a no-op, and the clean-up is left for later.
* It does not upstream the abstract `InputTransform` classes to GPyTorch. That'll be done if we decide to go forward with this design.
* It does not touch `PairwiseGP`. `PairwiseGP` has some non-standard use of input transforms, so it needs an audit to make sure things still work fine.
* I didn't look into `ApproximateGP.fantasize`. This may need some changes similar to `ExactGP.get_fantasy_model`.
* It does not support `PyroGP` and `DeepGP`.

Differential Revision: D39147547

fbshipit-source-id: ed2745b0ff666a13764759e1511a139c228d1d39
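For concreteness, here is a minimal, self-contained sketch of the proposed flow. These are toy stand-ins, not the actual GPyTorch/BoTorch classes, and the transform here always applies to test inputs, which glosses over `transform_on_eval`:

```python
import torch
from torch import nn


class ToyInputTransform(nn.Module):
    """Stand-in for an InputTransform that normalizes inputs to the unit cube."""

    def __init__(self, bounds: torch.Tensor, transform_on_train: bool = True):
        super().__init__()
        self.register_buffer("bounds", bounds)  # 2 x d tensor of (lower; upper) bounds
        self.transform_on_train = transform_on_train

    def forward(self, X: torch.Tensor, is_training_input: bool = False) -> torch.Tensor:
        # The caller, rather than a self.training check inside the transform,
        # decides whether X is a training input; transform_on_train only
        # gates training inputs.
        if is_training_input and not self.transform_on_train:
            return X
        return (X - self.bounds[0]) / (self.bounds[1] - self.bounds[0])


class ToyExactGP(nn.Module):
    def __init__(self, train_inputs, input_transform):
        super().__init__()
        self.train_inputs = train_inputs
        self.input_transform = input_transform

    def __call__(self, inputs):
        # train_inputs are always training inputs; generic inputs are training
        # inputs in train mode and test inputs in eval mode.
        train_X = self.input_transform(self.train_inputs, is_training_input=True)
        X = self.input_transform(inputs, is_training_input=self.training)
        return train_X, X  # a real model would evaluate the GP here


class ToyApproximateGP(nn.Module):
    def __init__(self, inducing_points, input_transform):
        super().__init__()
        # Inducing points are assumed to live in the transformed space and
        # are NEVER passed through the transform.
        self.inducing_points = inducing_points
        self.input_transform = input_transform

    def __call__(self, inputs):
        X = self.input_transform(inputs, is_training_input=self.training)
        return self.inducing_points, X
```

Flipping `model.train()` / `model.eval()` then changes how generic `inputs` are treated, without the transform itself ever inspecting the model's mode.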
This pull request was exported from Phabricator. Differential Revision: D39147547
I like this design!
I think we probably want to have something like a gpytorch […]
@gpleiss do you have any high-level feedback on the transform setup (https://github.com/pytorch/botorch/tree/main/botorch/models/transforms) that we'd want to incorporate when upstreaming those? One point that @j-wilson had brought up is that if the transforms are expensive and not learnable (e.g. a pre-fit NN feature extractor), then repeatedly applying them to the same inputs during training (for the full-batch case of exact GPs, anyway) could be quite wasteful. Is there an elegant solution to this by means of caching the transformed values of the training data and evicting that cache when they are reset?
@Balandat I really like the botorch API, and this would be super useful to have upstream in GPyTorch!
There probably is an elegant way to do this, but nothing really comes to mind. We should circle back to this at some point, but at the very least a power user could (e.g.) apply a pre-trained NN to the inputs without using the transforms API.
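To make the cache-and-evict idea above concrete, here is one possible sketch, not part of this diff and purely hypothetical: memoize the transformed training inputs and recompute only when the raw tensor changes. Using `data_ptr()` together with PyTorch's internal `_version` counter to detect resets is an assumption of this sketch, not an established pattern:

```python
import torch
from torch import nn


class CachedInputTransform(nn.Module):
    """Memoizes an expensive, non-learnable transform on the training inputs.

    Test inputs are transformed on every call; training inputs are only
    re-transformed when the underlying tensor changes.
    """

    def __init__(self, transform: nn.Module):
        super().__init__()
        self.transform = transform
        self._cache_key = None
        self._cache_val = None

    def forward(self, X: torch.Tensor, is_training_input: bool = False) -> torch.Tensor:
        if not is_training_input:
            return self.transform(X)  # test inputs vary per call; no caching
        # data_ptr() identifies the storage; _version (PyTorch's internal
        # in-place-mutation counter) detects resets of the same tensor.
        key = (X.data_ptr(), X._version, tuple(X.shape))
        if key != self._cache_key:
            with torch.no_grad():  # transform is assumed non-learnable
                self._cache_val = self.transform(X)
            self._cache_key = key
        return self._cache_val
```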