
Add "fit only pytorch models" flag #291

Open
radka-j opened this issue Feb 3, 2025 · 5 comments
@radka-j
Member

radka-j commented Feb 3, 2025

Allow users to easily choose to fit only emulators that have a PyTorch backend (currently these are GPs and CNPs). This is useful in cases where a downstream task relies on a PyTorch model.

@mastoffel
Collaborator

Just adding here that when this flag is enabled, we should also deactivate data pre-processing, since that is done through scikit-learn pipelines. We can then raise a warning telling the user to pre-process manually beforehand.
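To make the warning concrete, the manual step it would point to might look like this (a hedged sketch: `X`, `y` and the choice of `StandardScaler` are illustrative, not AutoEmulate's actual pipeline):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical illustration: if AutoEmulate skips its internal
# scikit-learn pipelines when returning pure PyTorch models, the user
# standardises the inputs themselves before fitting.
X = np.random.rand(100, 5)
y = np.random.rand(100, 2)

scaler_X = StandardScaler().fit(X)
X_scaled = scaler_X.transform(X)
# Keep scaler_X around: new points must be transformed the same way
# before being passed to the extracted PyTorch model.
```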

@mastoffel
Collaborator

A bit of a brain dump here so that we can discuss this: I've just had a look at this issue and #295 (i.e. running and extracting PyTorch models). It's not as straightforward as I thought, because both Neural Processes and GPs need objects/data outside of the PyTorch object itself, and both are quite specific. AutoEmulate handles these things in the background, but by stripping away the estimator object and returning the pure PyTorch object, we lose this functionality and leave it to the user to figure out.

  1. The CNP forward method needs context and target points:

def forward(self, X_context, y_context, X_target=None, context_mask=None):

AutoEmulate internally just takes the training data as context points, so in the predict method the user only has to provide target points X. For training, the CNP uses a dataset that coordinates the sampling and is slightly unusual, as it creates a meta-dataset from a normal dataset. So to do further training, the user would need that object too, I guess. For an Attentive CNP, the user would also need a context_mask for training (which the PyTorch dataset in AutoEmulate also takes care of).

  2. The GP PyTorch object returns a MultivariateNormal object, rather than posterior predictions. To get those, we need a likelihood function, see here. We can extract this one from the object, though.
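To make point 1 concrete, here is a minimal, hypothetical stand-in for the CNP interface (`TinyCNP` and its mean-pooling aggregation are invented for illustration; only the forward signature mirrors the real model):

```python
import torch

# TinyCNP stands in for the real CNP: its forward needs context points,
# which AutoEmulate normally supplies from the stored training data.
class TinyCNP(torch.nn.Module):
    def __init__(self, x_dim, y_dim):
        super().__init__()
        self.net = torch.nn.Linear(x_dim + y_dim + x_dim, y_dim)

    def forward(self, X_context, y_context, X_target=None, context_mask=None):
        # Real CNPs encode and aggregate the context set; here we just
        # mean-pool it and concatenate with each target point.
        ctx = torch.cat([X_context, y_context], dim=-1).mean(dim=0)
        ctx = ctx.expand(X_target.shape[0], -1)
        return self.net(torch.cat([ctx, X_target], dim=-1))

X_train, y_train = torch.randn(20, 3), torch.randn(20, 2)
model = TinyCNP(x_dim=3, y_dim=2)

# A user holding only `model` cannot predict: they also need the
# training data to pass as context.
preds = model(X_train, y_train, X_target=torch.randn(5, 3))
```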

So the question is how to go ahead. A few thoughts:

  • extract all relevant objects and return them in a tuple. This means all PyTorch models will have different outputs, which might be confusing.
  • extract only the main PyTorch object and write tutorials on how to use it further
  • something else?

Would be great to get your input here @marjanfamili @radka-j

@radka-j
Member Author

radka-j commented Feb 11, 2025

@mastoffel can you point me to where in the code autoemulate does these additional steps (I just don't know the codebase very well yet)?

@mastoffel
Collaborator

  • the PyTorch model underlying the CNP (and Attentive CNP) needs X_context, y_context and X_target, where X_target is the data to predict on and the context tensors are the training data. AutoEmulate just uses the training data, which it saved as attributes, as context data; see the predict function in condition_neural_process.py

  • the PyTorch model underlying the GP works on its own with just inputs X, but it returns a gpytorch.distributions.MultitaskMultivariateNormal, which is a distribution. To get actual posterior values, we need to provide a likelihood, which is done in the fit method of the wrapper class. There, we wrap the PyTorch model in a skorch ExactGPRegressor, which we provide with the likelihood and training specifics (n epochs, optimizer, etc.)

So to do a proper forward pass from inputs X to outputs y, we need additional objects (contexts for the CNP, a likelihood for the GP), which the user has to figure out themselves if we only provide the pure fitted PyTorch model.
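A sketch of the GP half of this, using torch.distributions in place of gpytorch (the shapes and noise value are illustrative assumptions, not AutoEmulate code):

```python
import torch
from torch.distributions import MultivariateNormal

# The raw GP model returns a latent distribution, not predictions.
latent = MultivariateNormal(torch.zeros(4), torch.eye(4))

# The likelihood adds observation noise on top of the latent function;
# only then do mean/variance describe actual posterior predictions.
noise = 0.1
predictive = MultivariateNormal(
    latent.mean, latent.covariance_matrix + noise * torch.eye(4)
)

y_pred = predictive.mean      # point predictions
y_var = predictive.variance   # predictive uncertainty per point
```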

@radka-j
Member Author

radka-j commented Feb 11, 2025

Thanks, this is really helpful!

I think we should return everything as a single torch object that holds the data it needs and has a predict() method that does the right thing with it. How difficult would this be?
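Something like this hypothetical wrapper, perhaps (names and structure are a sketch, not existing AutoEmulate code; `DummyCNP` is a stand-in showing where the real CNP or GP would slot in):

```python
import torch

# A single object bundling the pure PyTorch model with whatever extras
# it needs: context data for CNPs, a likelihood for GPs.
class StandaloneEmulator(torch.nn.Module):
    def __init__(self, model, X_context=None, y_context=None, likelihood=None):
        super().__init__()
        self.model = model
        self.X_context = X_context
        self.y_context = y_context
        self.likelihood = likelihood  # e.g. a gpytorch likelihood, or None

    @torch.no_grad()
    def predict(self, X):
        if self.X_context is not None:   # CNP-style: supply stored context
            out = self.model(self.X_context, self.y_context, X_target=X)
        else:                            # GP-style: plain forward pass
            out = self.model(X)
        if self.likelihood is not None:  # latent distribution -> predictions
            out = self.likelihood(out)
        return out

# Usage with a stand-in model:
class DummyCNP(torch.nn.Module):
    def forward(self, X_context, y_context, X_target=None, context_mask=None):
        return X_target.sum(dim=-1, keepdim=True)

emu = StandaloneEmulator(
    DummyCNP(), X_context=torch.randn(10, 3), y_context=torch.randn(10, 1)
)
preds = emu.predict(torch.ones(4, 3))
```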
