"""VIME: Extending the Success of Self- and Semi-supervised Learning to Tabular Domain (VIME) Codebase.
Reference: Jinsung Yoon, Yao Zhang, James Jordon, Mihaela van der Schaar,
"VIME: Extending the Success of Self- and Semi-supervised Learning to Tabular Domain,"
Neural Information Processing Systems (NeurIPS), 2020.
Paper link: TBD
Last updated Date: October 11th 2020
Code author: Jinsung Yoon ([email protected])
-----------------------------
vime_self.py
- Self-supervised learning parts of the VIME framework
- Using unlabeled data to train the encoder
"""

# Necessary packages
from keras.layers import Input, Dense
from keras.models import Model

from vime_utils import mask_generator, pretext_generator
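
# Note on the vime_utils helpers (assumed contracts, based on how they are
# used below): mask_generator(p_m, x) is expected to draw a binary mask matrix
# of the same shape as x with entries ~ Bernoulli(p_m), and
# pretext_generator(m, x) is expected to return (m_new, x_tilde), where
# x_tilde mixes original and column-shuffled feature values according to the
# mask and m_new marks the entries where x_tilde differs from x.
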
def vime_self(x_unlab, p_m, alpha, parameters):
  """Self-supervised learning part of VIME.

  Args:
    x_unlab: unlabeled features
    p_m: corruption probability
    alpha: hyper-parameter controlling the trade-off between the feature and mask losses
    parameters: dictionary with 'epochs' and 'batch_size'

  Returns:
    encoder: trained representation-learning block (Keras model)
  """
  # Parameters
  _, dim = x_unlab.shape
  epochs = parameters['epochs']
  batch_size = parameters['batch_size']

  # Build model
  inputs = Input(shape=(dim,))
  # Encoder: single dense layer with the same width as the input
  h = Dense(dim, activation='relu')(inputs)
  # Mask estimator: predicts which entries were corrupted
  output_1 = Dense(dim, activation='sigmoid', name='mask')(h)
  # Feature estimator: reconstructs the original (uncorrupted) features
  output_2 = Dense(dim, activation='sigmoid', name='feature')(h)

  # Total loss = binary cross-entropy on the mask head
  # + alpha * mean squared error on the feature head
  model = Model(inputs=inputs, outputs=[output_1, output_2])
  model.compile(optimizer='rmsprop',
                loss={'mask': 'binary_crossentropy',
                      'feature': 'mean_squared_error'},
                loss_weights={'mask': 1, 'feature': alpha})

  # Generate corrupted samples and the corresponding mask labels
  m_unlab = mask_generator(p_m, x_unlab)
  m_label, x_tilde = pretext_generator(m_unlab, x_unlab)

  # Fit the model on unlabeled data: recover the mask and the original features
  model.fit(x_tilde, {'mask': m_label, 'feature': x_unlab},
            epochs=epochs, batch_size=batch_size)

  # Extract the encoder part (the first hidden layer) as a standalone model
  encoder = Model(inputs=model.input, outputs=model.layers[1].output)

  return encoder
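

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module): trains
# the self-supervised encoder on synthetic unlabeled data. The data shape,
# p_m, alpha, and the parameters dictionary below are hypothetical placeholder
# values chosen for demonstration only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
  import numpy as np

  # Hypothetical unlabeled data: 1000 samples with 10 features in [0, 1]
  x_unlab = np.random.rand(1000, 10)

  vime_parameters = {'epochs': 10, 'batch_size': 128}
  encoder = vime_self(x_unlab, p_m=0.3, alpha=2.0, parameters=vime_parameters)

  # The encoder maps raw features to learned representations
  representations = encoder.predict(x_unlab)
  print(representations.shape)  # (1000, 10), since the hidden width equals dim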