Why is tc_loss in bTCVAE negative? #60
https://github.com/rtqichen/beta-tcvae/ calculates [equation image] and, in the case of minibatch stratified sampling, [equation image]. In this codebase, shouldn't we also do [equation image] (in the case of NOT is_mss) and [equation image] (in the case of is_mss)?
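For concreteness, here is how I read the minibatch weighted sampling (MWS) estimator in Chen et al.'s code: both log q(z) and the sum of the marginal log-densities are estimated by a log-sum-exp over the minibatch minus log(batch_size * dataset_size). A minimal NumPy sketch (the names `mws_estimates` and the toy `logsumexp` helper are mine, not from either repo):

```python
import numpy as np

def logsumexp(a, axis):
    """Numerically stable log-sum-exp along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))

def mws_estimates(mat_log_qz, data_size):
    """Minibatch weighted sampling estimates (Chen et al., 2018).

    mat_log_qz[i, j, d] = log q(z_i[d] | x_j), shape (B, B, D), where z_i is
    the latent sample for datapoint i and x_j ranges over the minibatch.
    """
    batch_size = mat_log_qz.shape[0]
    log_mn = np.log(batch_size * data_size)
    # log q(z): sum over latent dims first, then log-average over the batch
    log_qz = logsumexp(mat_log_qz.sum(axis=2), axis=1) - log_mn
    # log prod_d q(z_d): log-average over the batch per dim, then sum dims
    log_prod_qzi = (logsumexp(mat_log_qz, axis=1) - log_mn).sum(axis=1)
    return log_qz, log_prod_qzi
```

As a quick sanity check: with all log-densities zero (unit densities), each per-dimension estimate reduces to log(B) - log(B * N) = -log(N).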
Thanks @UserName-AnkitSisodia! Did you test it with these changes?
Using some random matrices (code attached), I ran your code as well as Ricky Chen's code to compare what is happening. I found: MWS [results image], MSS [results image]. So, when I use your code with is_mss=True, I get -ve tc_loss, and with is_mss=False, I get -ve mi_loss and -ve tc_loss. Then I changed the _get_log_pz_qz_prodzi_qzCx function in your code to make it similar to Ricky Chen's code. Then I get +ve losses for everything when is_mss=True, but I get a -ve dw_kl_loss term.
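For anyone reproducing this sign check: the three reported terms come from the four log-density estimates via the KL decomposition of Chen et al. A sketch of that bookkeeping (the function name is mine; I'm assuming the mi/tc/dw_kl naming used in this thread):

```python
import numpy as np

def btcvae_terms(log_pz, log_qz, log_prod_qzi, log_q_zCx):
    """Decompose E[KL(q(z|x) || p(z))] into the three beta-TC-VAE terms
    (Chen et al., 2018): index-code MI + total correlation + dim-wise KL."""
    mi_loss = np.mean(log_q_zCx - log_qz)        # index-code mutual information
    tc_loss = np.mean(log_qz - log_prod_qzi)     # total correlation
    dw_kl_loss = np.mean(log_prod_qzi - log_pz)  # dimension-wise KL
    return mi_loss, tc_loss, dw_kl_loss
```

Because the sum telescopes, the three terms always add up to the plain KL estimate mean(log_q_zCx - log_pz), whichever sampling scheme produced log_qz and log_prod_qzi; a negative individual term therefore points at bias in the corresponding log-density estimate, not at the bookkeeping.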
Awesome, thanks for checking. A few comments: 1/ What do you mean by "+ve" and "-ve"? What is "ve"? 2/ Looking back at it, it seems that I actually had the correct code and then introduced the problem in a late-night push ( #43 ). Here's what I had before my changes:

def _get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist, n_data, is_mss=False):
    batch_size, hidden_dim = latent_sample.shape

    # calculate log q(z|x)
    log_q_zCx = log_density_gaussian(latent_sample, *latent_dist).sum(dim=1)

    # calculate log p(z)
    # mean and log var is 0
    zeros = torch.zeros_like(latent_sample)
    log_pz = log_density_gaussian(latent_sample, zeros, zeros).sum(1)

    if not is_mss:
        log_qz, log_prod_qzi = _minibatch_weighted_sampling(latent_dist,
                                                            latent_sample,
                                                            n_data)
    else:
        log_qz, log_prod_qzi = _minibatch_stratified_sampling(latent_dist,
                                                              latent_sample,
                                                              n_data)

    return log_pz, log_qz, log_prod_qzi, log_q_zCx


def _minibatch_weighted_sampling(latent_dist, latent_sample, data_size):
    """
    Estimates log q(z) and the log (product of marginals of q(z_j)) with
    minibatch weighted sampling.

    Parameters
    ----------
    latent_dist : tuple of torch.Tensor
        Sufficient statistics of the latent dimension. E.g. for a Gaussian,
        (mean, log_var), each of shape (batch_size, latent_dim).
    latent_sample : torch.Tensor
        Sample from the latent dimension using the reparameterisation trick.
        Shape: (batch_size, latent_dim).
    data_size : int
        Number of data points in the training set.

    References
    ----------
    [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in
        variational autoencoders." Advances in Neural Information Processing
        Systems. 2018.
    """
    batch_size = latent_sample.size(0)

    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)

    log_prod_qzi = (torch.logsumexp(mat_log_qz, dim=1, keepdim=False) -
                    math.log(batch_size * data_size)).sum(dim=1)
    log_qz = torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False
                             ) - math.log(batch_size * data_size)

    return log_qz, log_prod_qzi


def _minibatch_stratified_sampling(latent_dist, latent_sample, data_size):
    """
    Estimates log q(z) and the log (product of marginals of q(z_j)) with
    minibatch stratified sampling.

    Parameters
    ----------
    latent_dist : tuple of torch.Tensor
        Sufficient statistics of the latent dimension. E.g. for a Gaussian,
        (mean, log_var), each of shape (batch_size, latent_dim).
    latent_sample : torch.Tensor
        Sample from the latent dimension using the reparameterisation trick.
        Shape: (batch_size, latent_dim).
    data_size : int
        Number of data points in the training set.

    References
    ----------
    [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in
        variational autoencoders." Advances in Neural Information Processing
        Systems. 2018.
    """
    batch_size = latent_sample.size(0)

    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)
    log_iw_mat = log_importance_weight_matrix(batch_size, data_size).to(latent_sample.device)
    log_qz = torch.logsumexp(log_iw_mat + mat_log_qz.sum(2), dim=1, keepdim=False)
    log_prod_qzi = torch.logsumexp(log_iw_mat.view(batch_size, batch_size, 1) +
                                   mat_log_qz, dim=1, keepdim=False).sum(1)

    return log_qz, log_prod_qzi

which is (I believe) exactly what you tested.
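To run both paths on random matrices as in the test above, the MSS branch also needs the importance-weight matrix. Below is a NumPy transcription of the weight layout I see in the reference implementation, plus the MSS estimator; treat it as a sketch (the NumPy port and the `logsumexp`/`mss_estimates` helpers are mine):

```python
import numpy as np

def log_importance_weight_matrix(batch_size, dataset_size):
    """Importance weights for minibatch stratified sampling (Chen et al., 2018).

    With N = dataset_size and M = batch_size - 1, each row mixes weights
    1/N, (N - M)/(N * M), and 1/M so that (for most rows) they sum to 1.
    """
    N, M = dataset_size, batch_size - 1
    strat_weight = (N - M) / (N * M)
    W = np.full((batch_size, batch_size), 1.0 / M)
    W[:, 0] = 1.0 / N            # one sample weighted as if drawn from q(z|x_i)
    W[:, 1] = strat_weight       # stratification correction term
    W[M - 1, 0] = strat_weight
    return np.log(W)

def logsumexp(a, axis):
    """Numerically stable log-sum-exp along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))

def mss_estimates(mat_log_qz, data_size):
    """MSS estimates of log q(z) and sum_d log q(z_d) from the (B, B, D)
    matrix with entries log q(z_i[d] | x_j)."""
    batch_size = mat_log_qz.shape[0]
    log_iw = log_importance_weight_matrix(batch_size, data_size)
    log_qz = logsumexp(log_iw + mat_log_qz.sum(axis=2), axis=1)
    log_prod_qzi = logsumexp(log_iw[:, :, None] + mat_log_qz, axis=1).sum(axis=1)
    return log_qz, log_prod_qzi
```

Note that the weights land on fixed columns rather than the diagonal; since the minibatch is randomly ordered, only the multiset of weights per row matters in expectation. With unit densities everywhere, a row whose weights sum to 1 yields an estimate of exactly 0, which makes a convenient sanity check.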
Yes, this makes the code exactly the same. Once these changes are made, I get a negative dw_kl_loss term in the case of _minibatch_weighted_sampling. For _minibatch_stratified_sampling, all loss terms are positive. I tested on dSprites.
And qualitatively, do you see any differences?
I didn't test that yet. I was just trying to see, from the math/code, where I am getting the error.
Has this issue been solved? Training on dSprites, I also get a negative tc loss.
I also got a negative loss with the dSprites data.
tc loss (see disentangling-vae/results/btcvae_dsprites/train_losses.log, line 5 at commit 535bbd2)