-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add option to change covariance matrix type for GMM class #50
base: master
Are you sure you want to change the base?
Conversation
) | ||
|
||
|
||
def _determine_shapes(components, features, covariance_type, tied, cov_rank): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The following pattern would allow to reduce code duplication
leading = 1 if tied else components
if covariance_type == 'full':
shapes.extend([
(leading, features),
(leading, features * (features - 1) // 2),
])
elif ...
Hello @dominik-strutz, I quickly went over the code, and it looks nice! I think we should add some tests however, maybe in a new Are you also planning to improve the initialization as well? |
Hi @francois-rozet, Yes, I am happy to write some tests. I'm also happy to try to improve the initialization. Following What is your opinion on how to structure the initialization? I think it would be beneficial to keep the I will give the initialisation a try and let you know how it goes. P.S: I have no idea why the pre-commit hook fails. I used |
I think a good way to handle the conditional case would be to make the weight
I agree that a separate method could be appropriate, similar to the
I pulled your branch and
Maybe you were not at the root? My version of ruff is |
@dominik-strutz Do you still plan on contributing this PR? |
Yes, I still like to contribute but haven't found much free time to do it recently. I have implemented most of the initialization methods for the unconditional case, but it still needs to be polished up and tested. The extension for the conditional case shouldn't take too long afterwards. If you or someone else wants to continue this sooner, I'm happy to push an intermediary commit of everything I have so far. |
No problem, take your time! I am currently updating a few things and wanted to know if I should wait for this PR for the next minor release. |
This PR adds changes to the
zuko.flows.mixture.GMM
class, which allow the user to change the type of the covariance matrix used for each of the Gaussian components of the mixture.The options added are
covariance_type
, which allows to change the type of the covariance matricestied
a switch which allows to control if covariance matrices are tied between componentscov_rank
the rank of the low-rank covariance matrix whencovariance_type
is 'lowrank'Since the construction of the shapes got quite long I moved this part in its own function.
Below is an illustration of the effect these different choices have for a mixture of 3 two-dimensional Gaussians.