Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes to pyro model initialisation & sampling [WIP] #2695

Open
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

vitkl
Copy link
Contributor

@vitkl vitkl commented Apr 7, 2024

Addresses #2616

Replaces #1805

@vitkl
Copy link
Contributor Author

vitkl commented Apr 8, 2024

I don't fully understand the reason for the errors - they don't happen in test_pyro_bayesian_regression_low_level, test_pyro_bayesian_regression, test_pyro_bayesian_regression_jit - but they happen when using train() directly. This approach works for cell2location.

The difference maybe the timing when the plates are first used. I will look into this later.

@vitkl
Copy link
Contributor Author

vitkl commented Apr 8, 2024

Also this code for posterior sampling is indeed ~2-3x faster but it creates samples of huge observed data matrixes (copies data n_samples times - eg 1000):

        if isinstance(self.module.guide, poutine.messenger.Messenger):
            # This already includes trace-replay behavior.
            sample = self.module.guide(*args, **kwargs)

An alternative way to deal with this issue would be this:

        if isinstance(self.module.guide, poutine.messenger.Messenger):
            # This already includes trace-replay behavior.
            sample = self.module.guide(*args, **kwargs)
            # include and exclude requested sites
            sample = {k: v for k, v in sample.items() if k in return_sites}
            sample = {k: v for k, v in sample.items() if k not in exclude_vars}   # this has to be provided by model developer

@martinkim0 What do you think we should do? What do you think about the initialisation solution?

@martinkim0 martinkim0 self-assigned this Apr 8, 2024
@martinkim0
Copy link
Contributor

@vitkl hey sorry for the delay, I'm planning on taking a look at this tomorrow!

@martinkim0
Copy link
Contributor

This is actually my first time at taking a look at some of our Pyro code - I hadn't really interacted with it before. So I don't really understand the reason why some things are done, e.g., the warmup callbacks. I definitely need to take a deep dive into all of this.

However, it looks like both PyroJitGuideWarmup and PyroModelGuideWarmup are just passing in a single minibatch through the guide prior to the training loop, so I like the idea of having a method like setup_pyro_model that does this. I think this makes more sense in the training plan though, using one of the Lightning hooks such as on_train_start. And there's definitely something weird going on with tensors on different devices, and I think using one of the Lightning hooks would solve this since their backend will take care of moving tensors.

Regarding the sampling changes, would it be possible to include that in a separate PR? And then we can discuss that there. Thanks!

@vitkl
Copy link
Contributor Author

vitkl commented Apr 12, 2024

Just a brief reply. Happy to have a zoom call about pyro.

Pyro automatic variational distribution (Guide) doesn’t have any parameters until you do a first pass through the model and guide. When moving my code to multi-GPU training I found that this needs to be done in setup step of the Lightning workflow - otherwise parameters created on GPU don’t get moved between devices correctly - so it’s it would not in on_train_start. However, in the latest version the setup step also doesn’t work - as reported in the original issue. Moving the code to this function and calling it before using any Lightning workflow steps seems to solve the problem for cell2location and my other project.

Actually the reason for the errors might be resolved if you call both the model and guide with one batch (it’s possibly the issue with LDA model that uses a custom guide).

@martinkim0 martinkim0 added the P1 label Jul 12, 2024
args, kwargs = pl_module.module._get_fn_args_from_batch(tens)
pyro_guide(*args, **kwargs)
break
for tensors in dataloader:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to do next(iter(dataloader)) to get a single batch. I think still having the class makes sense. Within this class, there can be a manual_start function.

break


class PyroModelGuideWarmup(Callback):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do those two classes exist in the first place?

@canergen
Copy link
Member

canergen commented Sep 5, 2024

Please split into two PRs. One for the warmup changes and one for the inference changes. This makes it easier to follow changes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants