# model.py (forked from ilya-shenbin/RecVAE)

import numpy as np
from copy import deepcopy

import torch
from torch import nn
from torch.nn import functional as F


def swish(x):
    # Swish activation: x * sigmoid(x).
    return x.mul(torch.sigmoid(x))


def log_norm_pdf(x, mu, logvar):
    # Element-wise log-density of a diagonal Gaussian N(mu, exp(logvar)).
    return -0.5 * (logvar + np.log(2 * np.pi) + (x - mu).pow(2) / logvar.exp())
class CompositePrior(nn.Module):
    def __init__(self, hidden_dim, latent_dim, input_dim, mixture_weights=[3/20, 3/4, 1/10]):
        super(CompositePrior, self).__init__()

        self.mixture_weights = mixture_weights

        # Standard-normal component: mu = 0, logvar = 0.
        self.mu_prior = nn.Parameter(torch.Tensor(1, latent_dim), requires_grad=False)
        self.mu_prior.data.fill_(0)

        self.logvar_prior = nn.Parameter(torch.Tensor(1, latent_dim), requires_grad=False)
        self.logvar_prior.data.fill_(0)

        # "Uniform-like" component: a very wide Gaussian (logvar = 10).
        self.logvar_uniform_prior = nn.Parameter(torch.Tensor(1, latent_dim), requires_grad=False)
        self.logvar_uniform_prior.data.fill_(10)

        # Frozen copy of the encoder, refreshed via VAE.update_prior(),
        # providing the previous-posterior component of the prior.
        self.encoder_old = Encoder(hidden_dim, latent_dim, input_dim)
        self.encoder_old.requires_grad_(False)

    def forward(self, x, z):
        post_mu, post_logvar = self.encoder_old(x, 0)

        stnd_prior = log_norm_pdf(z, self.mu_prior, self.logvar_prior)
        post_prior = log_norm_pdf(z, post_mu, post_logvar)
        unif_prior = log_norm_pdf(z, self.mu_prior, self.logvar_uniform_prior)

        gaussians = [stnd_prior, post_prior, unif_prior]
        gaussians = [g.add(np.log(w)) for g, w in zip(gaussians, self.mixture_weights)]

        # Mixture log-density: logsumexp over the weighted components.
        density_per_gaussian = torch.stack(gaussians, dim=-1)

        return torch.logsumexp(density_per_gaussian, dim=-1)
class Encoder(nn.Module):
    def __init__(self, hidden_dim, latent_dim, input_dim, eps=1e-1):
        super(Encoder, self).__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.ln1 = nn.LayerNorm(hidden_dim, eps=eps)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim, eps=eps)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.ln3 = nn.LayerNorm(hidden_dim, eps=eps)
        self.fc4 = nn.Linear(hidden_dim, hidden_dim)
        self.ln4 = nn.LayerNorm(hidden_dim, eps=eps)
        self.fc5 = nn.Linear(hidden_dim, hidden_dim)
        self.ln5 = nn.LayerNorm(hidden_dim, eps=eps)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x, dropout_rate):
        # L2-normalize each user's rating vector, then apply input dropout.
        norm = x.pow(2).sum(dim=-1).sqrt()
        x = x / norm[:, None]
        x = F.dropout(x, p=dropout_rate, training=self.training)

        # Five densely connected blocks: each layer also receives the sum of
        # all previous hidden states (swish activation + layer normalization).
        h1 = self.ln1(swish(self.fc1(x)))
        h2 = self.ln2(swish(self.fc2(h1) + h1))
        h3 = self.ln3(swish(self.fc3(h2) + h1 + h2))
        h4 = self.ln4(swish(self.fc4(h3) + h1 + h2 + h3))
        h5 = self.ln5(swish(self.fc5(h4) + h1 + h2 + h3 + h4))
        return self.fc_mu(h5), self.fc_logvar(h5)
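
# Shape sketch (illustrative note, not part of the original file): for
# user_ratings of shape (batch_size, input_dim), Encoder.forward returns
# mu and logvar, each of shape (batch_size, latent_dim). The hyperparameter
# values below are only an assumption for the example:
#
#     encoder = Encoder(hidden_dim=600, latent_dim=200, input_dim=n_items)
#     mu, logvar = encoder(user_ratings, dropout_rate=0.5)
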
class VAE(nn.Module):
    def __init__(self, hidden_dim, latent_dim, input_dim):
        super(VAE, self).__init__()

        self.encoder = Encoder(hidden_dim, latent_dim, input_dim)
        self.prior = CompositePrior(hidden_dim, latent_dim, input_dim)
        self.decoder = nn.Linear(latent_dim, input_dim)

    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def forward(self, user_ratings, beta=None, gamma=1, dropout_rate=0.5, calculate_loss=True):
        mu, logvar = self.encoder(user_ratings, dropout_rate=dropout_rate)
        z = self.reparameterize(mu, logvar)
        x_pred = self.decoder(z)

        if calculate_loss:
            # KL weight: either rescaled per user by the number of interactions
            # (gamma) or a fixed beta.
            if gamma:
                norm = user_ratings.sum(dim=-1)
                kl_weight = gamma * norm
            elif beta:
                kl_weight = beta

            # Multinomial log-likelihood term and KL against the composite prior.
            mll = (F.log_softmax(x_pred, dim=-1) * user_ratings).sum(dim=-1).mean()
            kld = (log_norm_pdf(z, mu, logvar) - self.prior(user_ratings, z)).sum(dim=-1).mul(kl_weight).mean()
            negative_elbo = -(mll - kld)

            return (mll, kld), negative_elbo
        else:
            return x_pred

    def update_prior(self):
        # Copy the current encoder weights into the prior's frozen "old" encoder.
        self.prior.encoder_old.load_state_dict(deepcopy(self.encoder.state_dict()))
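

# Minimal usage sketch (added for illustration, not part of the original
# repository). The hyperparameters, learning rate, and random "ratings" below
# are assumptions; the actual training script alternates encoder and decoder
# updates and uses real implicit-feedback data.
if __name__ == "__main__":
    n_items = 1000
    model = VAE(hidden_dim=600, latent_dim=200, input_dim=n_items)
    opt = torch.optim.Adam(model.parameters(), lr=5e-4)

    # Fake implicit-feedback batch: binary user-item interactions.
    user_ratings = torch.rand(32, n_items).round()

    (mll, kld), loss = model(user_ratings, gamma=0.005)
    opt.zero_grad()
    loss.backward()
    opt.step()

    # After an encoder training phase, refresh the frozen encoder copy
    # inside the composite prior.
    model.update_prior()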