
1602BayesDropout

Petr Baudis edited this page Mar 9, 2016 · 10 revisions

1602 Bayesian Dropout Experiments

In Feb 2016, Keras introduced Bayesian dropout based on http://arxiv.org/abs/1512.05287. Some bugfixes unfortunately followed; in some runs, results obtained with the fixes are marked as 1rnn, as opposed to 0rnn for results obtained before them.

Hypothesis: Bayesian dropout is awesome and will improve results compared to no (or standard) dropout.

We'll verify it on the rnn model, which is simple and has only one well-defined place where dropout is applied.

Run naming scheme: i is input dropout (first w, then e); d is inner dropout (first regular, then dropoutfix_inp and dropoutfix_rec). Each pair of hex digits encodes one fraction as numerator and denominator (14 == 1/4, 1a == 1/10, etc.); 00 means no dropout.
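To make the naming scheme easier to read back, here is a minimal sketch of a decoder for full-form tags such as i0012d001214. The function names are hypothetical, and mapping w/e to inp_w_dropout/inp_e_dropout and "regular" to dropout is an assumption based on the command-line parameters used on this page. Shortened tags (i0d0, i14d14) and tags with d before i are not handled.

```python
from fractions import Fraction

# Parameter order follows the naming scheme; the mapping of "w"/"e"/"regular"
# to these parameter names is an assumption, not confirmed by the tools.
INPUT_KEYS = ("inp_w_dropout", "inp_e_dropout")
INNER_KEYS = ("dropout", "dropoutfix_inp", "dropoutfix_rec")


def decode_pairs(digits):
    """Turn hex digit pairs into fractions: '14' -> 1/4, '1a' -> 1/10, '00' -> 0."""
    if len(digits) % 2:
        raise ValueError("expected an even number of hex digits: %r" % digits)
    fractions = []
    for k in range(0, len(digits), 2):
        num, den = int(digits[k], 16), int(digits[k + 1], 16)
        # A zero denominator ('00') is read as "no dropout".
        fractions.append(Fraction(num, den) if den else Fraction(0))
    return fractions


def decode_run_tag(tag):
    """Decode a full-form tag such as 'i0012d001214' into named dropout rates."""
    if not tag.startswith("i"):
        raise ValueError("tag must start with 'i': %r" % tag)
    # Assumption: 'd' only ever acts as the separator, never as a hex digit.
    inp, sep, inner = tag[1:].partition("d")
    inp_rates, inner_rates = decode_pairs(inp), decode_pairs(inner)
    if not sep or len(inp_rates) != 2 or len(inner_rates) != 3:
        raise ValueError("only the full i<w><e>d<reg><inp><rec> form is handled: %r" % tag)
    rates = dict(zip(INPUT_KEYS, inp_rates))
    rates.update(zip(INNER_KEYS, inner_rates))
    return rates


print(decode_run_tag("i0012d001214"))
```

Using Fraction keeps the rates exact, so 1/3 and 2/3 are not rounded when comparing run tags.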

yodaqa curatedv2 dataset (smaller)

Result: ???

python tools/anssel_train.py rnn data/anssel/yodaqa/curatedv2-training.csv data/anssel/yodaqa/curatedv2-val.csv inp_e_dropout=1/2 inp_w_dropout=1/2
	ay_rnn_i1212 10613674.arien.ics.muni.cz
Epoch 11/16
20000/20000 [==============================] - 189s - loss: 0.1401 - val_loss: 0.2315                                       val mrr 0.341237
Val MRR: 0.365388

	ay_rnn_i0012 10613685.arien.ics.muni.cz
Epoch 9/16
20000/20000 [==============================] - 182s - loss: 0.1312 - val_loss: 0.2426                                       val mrr 0.249609
Val MRR: 0.367208

	ay_rnn_i0013 10613705.arien.ics.muni.cz
Epoch 8/16
20000/20000 [==============================] - 182s - loss: 0.1138 - val_loss: 0.2849                                       val mrr 0.296654
Val MRR: 0.371005

	ay_rnn_i0023 10613706.arien.ics.muni.cz
Epoch 10/16
20000/20000 [==============================] - 185s - loss: 0.1525 - val_loss: 0.2230                                       val mrr 0.326151
Val MRR: 0.396984

rnn-169ec21c3bba4f6b ay_rnn_i0012d001214 10613701.arien.ics.muni.cz
Epoch 8/16
20000/20000 [==============================] - 196s - loss: 0.1576 - val_loss: 0.2400                                       val mrr 0.310440
Val MRR: 0.410301
Val MRR: 0.411793
Val MRR: 0.381539
Val MRR: 0.356278
Val MRR: 0.393808
Val MRR: 0.413956

	ay_rnn_i0012d001212 10613703.arien.ics.muni.cz
Epoch 13/16
20000/20000 [==============================] - 201s - loss: 0.1316 - val_loss: 0.2650                                       val mrr 0.347770
Val MRR: 0.362450

	ay_rnn_i0012d001412 10613709.arien.ics.muni.cz
Epoch 11/16
20000/20000 [==============================] - 198s - loss: 0.2247 - val_loss: 0.2158                                       val mrr 0.228685
Val MRR: 0.292240

	ay_rnn_i1512d151412 10613710.arien.ics.muni.cz
Epoch 10/16
20000/20000 [==============================] - 199s - loss: 0.1500 - val_loss: 0.2441                                       val mrr 0.311325
Val MRR: 0.395860

ay_rnn_i0013d001212 stability check (cross with the 1602Stats experiment):
rnn--3c89b2d533151ccc   Val MRR: 0.399063
rnn--23d44845606d86f    Val MRR: 0.363348
rnn--3582cd0486099c59   Val MRR: 0.399063
rnn--23d44845606d86f    Val MRR: 0.363348
rnn--3582cd0486099c59   Val MRR: 0.343782
rnn-180a1755b94f5109    Val MRR: 0.352646
rnn-69e54a7387c40111    Val MRR: 0.326337
rnn--40ac0770d7b61c20   Val MRR: 0.343964
rnn--3d83947cd56a2c19   Val MRR: 0.378791
rnn--6bc01b5c8f6e0623   Val MRR: 0.334495
rnn--2ef2fface53d36bf   Val MRR: 0.325264
rnn--4555ade20dd9e920   Val MRR: 0.323975
still overfits:
Epoch 9/16
20000/20000 [==============================] - 204s - loss: 0.1364 - val_loss: 0.2685                                       val mrr 0.252401
Epoch 10/16
20000/20000 [==============================] - 202s - loss: 0.1278 - val_loss: 0.2930                                       val mrr 0.288132
Epoch 11/16
20000/20000 [==============================] - 202s - loss: 0.1191 - val_loss: 0.2405                                       val mrr 0.262947

	ay_rnn_i0023d002323 10627954.arien.ics.muni.cz
Val MRR: 0.332978
	ay_rnn_i0014d002323 10627975.arien.ics.muni.cz
Val MRR: 0.353632
	ay_rnn_i0034d003434 10627963.arien.ics.muni.cz
Val MRR: 0.251581
	ay_rnn_i0014d003434 10627969.arien.ics.muni.cz
Val MRR: 0.358861

So far, we have shown that:

  • To be able to draw conclusions, multiple runs are required (fairly large result instability even with Bayesian dropout)
  • rnn configuration ay_rnn_i0012d001214 looks promising
  • Small regular dropout doesn't seem harmful
  • Very high Bayesian dropout seems bad

Computing statistics (TODO: apply a Bonferroni correction when comparing configurations):

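import numpy as np
import scipy.stats as ss

def stat(r):
    # Half-width of the 95% Student-t confidence interval of the mean (np.std defaults to ddof=0).
    bar = ss.t.isf((1-0.95)/2, len(r)-1) * np.std(r)/np.sqrt(len(r))
    print('%f (95%% [%f, %f])' % (np.mean(r), np.mean(r) - bar, np.mean(r) + bar))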
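The Bonferroni TODO above could be handled roughly as follows. This is only a sketch: the helper names are hypothetical, and the p-values are placeholders rather than results from these runs. The correction divides the significance level by the number of comparisons made against the baseline.

```python
def bonferroni_alpha(alpha, n_comparisons):
    """Per-comparison significance level keeping the family-wise error rate <= alpha."""
    if n_comparisons < 1:
        raise ValueError("need at least one comparison")
    return alpha / n_comparisons


def bonferroni_reject(p_values, alpha=0.05):
    """Return, for each p-value, whether it stays significant after the correction."""
    threshold = bonferroni_alpha(alpha, len(p_values))
    return [p < threshold for p in p_values]


# Hypothetical p-values for three "dropout variant vs. no dropout" comparisons;
# these numbers are placeholders, not results from the runs on this page.
example_p_values = [0.001, 0.02, 0.04]
print(bonferroni_alpha(0.05, len(example_p_values)))
print(bonferroni_reject(example_p_values))
```

For the intervals printed by stat, the same idea would amount to replacing 0.95 with 1 - 0.05/m when m configurations are compared.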

64x RNN with no dropout - 0.341914 (95% [0.335095, 0.348732]):

10658668.arien.ics.muni.cz.ay_1rnnd0 etc. (4x),
10676217.arien.ics.muni.cz.ay_1rnn_i0d0 etc.
10676263.arien.ics.muni.cz.ay_1rnn_i0d0 etc.
10679582.arien.ics.muni.cz.ay_1rnn_i0d0 etc.
[0.307231, 0.322597, 0.347030, 0.309661, 0.375453, 0.332430, 0.354851, 0.299810, 0.380887, 0.307368, 0.344801, 0.366382, 0.403766, 0.364815, 0.318076, 0.295354, 0.345617, 0.356424, 0.346753, 0.340065, 0.326070, 0.325360, 0.325105, 0.358304, 0.335235, 0.336816, 0.364727, 0.311757, 0.382863, 0.289366, 0.365342, 0.346157, ] + [0.341826, 0.344752, 0.367301, 0.360583, 0.338028, 0.327622, 0.351288, 0.369168, 0.360573, 0.391714, 0.282086, 0.321398, 0.355703, 0.348531, 0.380449, 0.345309, 0.327486, 0.309876, 0.305115, 0.365353, 0.381869, 0.379832, 0.349501, 0.347066, 0.324666, 0.375235, 0.296426, 0.337168, 0.335412, 0.331074, 0.301670, ]

32x RNN with small non-Bayesian input-embedding-only dropout - 0.357776 (95% [0.349315, 0.366237]):

10679550.arien.ics.muni.cz.ay_1rnn_i14d0 etc.
[0.338217, 0.364905, 0.315362, 0.326124, 0.340052, 0.342657, 0.353875, 0.400332, 0.368478, 0.377953, 0.335330, 0.318467, 0.355779, 0.342826, 0.356464, 0.353479, 0.385363, 0.387733, 0.355364, 0.362987, 0.345662, 0.364570, 0.426060, 0.342305, 0.358173, 0.355512, 0.374328, 0.341010, 0.347028, 0.358308, 0.354759, 0.399366, ]

32x RNN with small non-Bayesian dropout - 0.349198 (95% [0.340539, 0.357858]):

10676152.arien.ics.muni.cz.ay_1rnn_i14d14 etc.
10677631.arien.ics.muni.cz.ay_1rnn_i14d14 etc.
[0.358731, 0.345882, 0.319929, 0.362027, 0.313022, 0.404190, 0.351598, 0.328947, ] + [0.351770, 0.354676, 0.304611, 0.322354, 0.382736, 0.332211, 0.330184, 0.342133, 0.365452, 0.352409, 0.368504, 0.333728, 0.368692, 0.386257, 0.315388, 0.347077, 0.368641, 0.374750, 0.304989, 0.380117, 0.356715, 0.340518, 0.345060, 0.361053, ]

32x RNN with non-Bayesian dropout - 0.365131 (95% [0.356652, 0.373611]):

10679614.arien.ics.muni.cz.ay_1rnn_i12d12
[0.325909, 0.360563, 0.391624, 0.357004, 0.317877, 0.355094, 0.407481, 0.359131, 0.330650, 0.380453, 0.367553, 0.353810, 0.373268, 0.385657, 0.363275, 0.388254, 0.357876, 0.377382, 0.362747, 0.336688, 0.399590, 0.335586, 0.336083, 0.350715, 0.371391, 0.376305, 0.396156, 0.419271, 0.359192, 0.346788, 0.367768, 0.373067, ]

29x RNN with large non-Bayesian dropout ay_1rnn_i23d23 - 0.385323 (95% [0.372930, 0.397716]):

10680751.arien.ics.muni.cz.ay_1rnn_i23d23 etc.
[0.343581, 0.399249, 0.420247, 0.417747, 0.402610, 0.376662, 0.426974, 0.361336, 0.348797, 0.416641, 0.353445, 0.371555, 0.396943, 0.347177, 0.455542, 0.380509, 0.330139, 0.417342, 0.430644, 0.362429, 0.414006, 0.347151, 0.337468, 0.366264, 0.427821, 0.395221, 0.369467, 0.377539, 0.379852, ]

32x RNN with Bayesian dropout ay_rnn_i0012d001214 - 0.360404 (95% [0.346454, 0.374355]):

10676160.arien.ics.muni.cz.ay_1rnn_i0012d001214 etc.
10677682.arien.ics.muni.cz.ay_1rnn_i0012d001214 etc.
[0.344790, 0.392762, 0.347871, 0.389196, 0.342889, 0.388482, 0.364078, 0.366837, ] + [0.273012, 0.403682, 0.363418, 0.405802, 0.344355, 0.356530, 0.250001, 0.350193, 0.343800, 0.378575, 0.372497, 0.336383, 0.350030, 0.349110, 0.356529, 0.384694, 0.370370, 0.365075, 0.370023, 0.313609, 0.375597, 0.345670, 0.483959, 0.353123, ]

35x RNN with larger Bayesian dropout i1412d001212 - 0.358243 (95% [0.350701, 0.365785]):

10677706.arien.ics.muni.cz.ay_1rnn_i1412d001212 etc.
[0.361612, 0.373292, 0.347602, 0.356730, 0.356264, 0.376356, 0.350416, 0.347349, 0.363616, 0.357525, 0.408988, 0.330629, 0.399791,] + [0.322100, 0.364582, 0.397343, 0.354578, 0.363808, 0.339027, 0.356916, 0.400061, 0.340568, 0.369741, 0.363005, 0.347732, 0.358376, 0.334798, 0.331826, 0.335582, 0.328133, 0.362974, 0.335881, 0.360270, 0.400516, 0.340520, ]

funny spike:
Epoch 6/16
15256/15256 [==============================] - 161s - loss: 0.2129 - val_loss: 0.2132                                       val mrr 0.376923
Epoch 7/16
15256/15256 [==============================] - 162s - loss: 0.1984 - val_loss: 0.2138                                       val mrr 0.483959
Epoch 8/16
15256/15256 [==============================] - 162s - loss: 0.1874 - val_loss: 0.2402                                       val mrr 0.405460

31x non-Bayesian + larger-Bayesian dropout combo ay_1rnn_i1212d121212 - 0.352122 (95% [0.342436, 0.361809]):

10680714.arien.ics.muni.cz.ay_1rnn_i1212d121212 etc.
[0.350931, 0.403541, 0.355157, 0.360650, 0.347558, 0.317763, 0.340936, 0.314052, 0.369013, 0.362852, 0.379455, 0.369208, 0.363547, 0.338830, 0.328033, 0.357893, 0.315735, 0.344831, 0.312284, 0.343195, 0.333577, 0.398691, 0.338599, 0.431639, 0.323473, 0.360827, 0.341913, 0.352072, 0.359825, 0.332215, 0.367495, ]

Further information we still want:

  • Big dropout check (wip)
  • Big normal dropout + big Bayesian dropout check (wip)
  • Big normal dropout + big Bayesian dropout without recurrent dropout (which is often not recommended) (wip)
  • Is smaller dropout better for yoda-large, since dropout is ineffective for ubuntu?

wang

8x RNN with no dropout:

10676221.arien.ics.muni.cz.aw_1rnn_i0d0 etc.
[0.841589, 0.864377, 0.862564, 0.869369, 0.880000, 0.867366, 0.884359, 0.851474, ]

8x RNN with small normal dropout:

[0.857656, 0.865641, 0.914416, 0.869487, 0.845425, 0.865321, 0.847546, 0.869872, ]

Ubuntu dataset (larger)

Result: Hypothesis false. No dropout is best dropout for Ubuntu!

Full model

The rnnt_sA configuration was the best in generic Ubuntu model tuning.

 sdim=1 pdim=1 "pact='tanh'" ptscorer=B.dot_ptscorer

Baseline (NO dropout at all):

python tools/ubuntu_train.py rnn data/anssel/ubuntu/v2-vocab.pickle data/anssel/ubuntu/v2-trainset.pickle data/anssel/ubuntu/v2-valset.pickle sdim=1 pdim=1 "pact='tanh'" ptscorer=B.dot_ptscorer (implicitly also dropout=0 inp_e_dropout=0)
rnn-540d2a87a380b377 rnnt_sA 10596014.arien.ics.muni.cz
Epoch 17/32
200064/200000 [==============================] - 1999s - loss: 0.1256                                                         val mrr 0.742399
Val MRR: 0.786334
Val 2-R@1: 0.910429
Val 10-R@1: 0.671472  10-R@2: 0.805010  10-R@5: 0.953170                                         ***BEST***

Input dropout in general (implicit: dropout=0):

python tools/ubuntu_train.py rnn data/anssel/ubuntu/v2-vocab.pickle data/anssel/ubuntu/v2-trainset.pickle data/anssel/ubuntu/v2-valset.pickle sdim=1 pdim=1 "pact='tanh'" ptscorer=B.dot_ptscorer inp_e_dropout=1/4 inp_w_dropout=1/4
rnn-5020d5826692db97 rnn_sAi1414 10613663.arien.ics.muni.cz
Epoch 20/32
200064/200000 [==============================] - 2307s - loss: 0.5181                                                         val mrr 0.670752
Val MRR: 0.672963
Val 2-R@1: 0.819376
Val 10-R@1: 0.540746  10-R@2: 0.658640  10-R@5: 0.861145

rnn_sAi1412 10613664.arien.ics.muni.cz
Epoch 30/32
200064/200000 [==============================] - 2283s - loss: 0.5373                                                         val mrr 0.654066
Val MRR: 0.660758
Val 2-R@1: 0.809407
Val 10-R@1: 0.526585  10-R@2: 0.647239  10-R@5: 0.843712

rnn_sAi1212 10613665.arien.ics.muni.cz
Epoch 32/32
200064/200000 [==============================] - 2329s - loss: 0.5563                                                         val mrr 0.648415
Val MRR: 0.654191
Val 2-R@1: 0.805675
Val 10-R@1: 0.520757  10-R@2: 0.632464  10-R@5: 0.839417

General dropout:

rnn_sAi001ad001a1a 10628061.arien.ics.muni.cz
Epoch 16/32
200064/200000 [==============================] - 2760s - loss: 0.2874                                                         val mrr 0.749541
Predict&Eval (best epoch)
Val MRR: 0.774093
Val 2-R@1: 0.903221
Val 10-R@1: 0.653170  10-R@2: 0.795297  10-R@5: 0.948415

rnn_sAi0015d001515 10628048.arien.ics.muni.cz
Epoch 8/32
200064/200000 [==============================] - 2773s - loss: 0.4714                                                         val mrr 0.741423
(training diverged in next epoch)
Predict&Eval (best epoch)
Val MRR: 0.741423
Val 2-R@1: 0.878834
Val 10-R@1: 0.616002  10-R@2: 0.748620  10-R@5: 0.924949

10631135.arien.ics.muni.cz.rnn_sAi0d0
Epoch 4/32
200064/200000 [==============================] - 2642s - loss: 0.4323                                                         val mrr 0.759431
then diverges

10631137.arien.ics.muni.cz.rnn_sAi0015d001a1a
diverges early

10631138.arien.ics.muni.cz.rnn_sAi0015d00151a
Epoch 12/32
200064/200000 [==============================] - 2856s - loss: 0.4724                                                         val mrr 0.738685
then diverges

10631139.arien.ics.muni.cz.rnn_sAi0015d001a15
Epoch 9/32
200064/200000 [==============================] - 2736s - loss: 0.4466                                                         val mrr 0.764238
then diverges

10631140.arien.ics.muni.cz.rnn_sAi1415d141a1a
diverges early

==> 10637658.arien.ics.muni.cz.0rnn_tA_d0i0 <==
200064/200000 [==============================] - 2333s - loss: 0.3449                                                         val mrr 0.776052
Predict&Eval (best epoch)
Val MRR: 0.776052
Val 2-R@1: 0.902096
Val 10-R@1: 0.657618  10-R@2: 0.793252  10-R@5: 0.949029

==> 10637659.arien.ics.muni.cz.0rnn_tA_d0i0 <==
200064/200000 [==============================] - 2318s - loss: 0.1229                                                         val mrr 0.745733
Predict&Eval (best epoch)
Val MRR: 0.783147
Val 2-R@1: 0.906544
Val 10-R@1: 0.667127  10-R@2: 0.802914  10-R@5: 0.952710

==> 10637877.arien.ics.muni.cz.0rnn_tA_d12i12 <==
(diverges)
Predict&Eval (best epoch)
Val MRR: 0.753497
Val 2-R@1: 0.892382
Val 10-R@1: 0.624847  10-R@2: 0.768814  10-R@5: 0.944274

==> 10637879.arien.ics.muni.cz.0rnn_tA_d12i14 <==
(diverges)
Predict&Eval (best epoch)
Val MRR: 0.715625
Val 2-R@1: 0.861196
Val 10-R@1: 0.584611  10-R@2: 0.715900  10-R@5: 0.907771

==> 10638098.arien.ics.muni.cz.0rnn_tA_d12i00 <==
(diverges)
Predict&Eval (best epoch)
Val MRR: 0.757757
Val 2-R@1: 0.894172
Val 10-R@1: 0.632924  10-R@2: 0.773824  10-R@5: 0.940235

"Very small" model

The rnnm_sA model (also called rn8_m later) is much smaller and uses just spad=80. It is only slightly worse than the full model, but about 3x faster to train.

"pact='tanh'" sdim=1/6 pdim=1/6 ptscorer=B.dot_ptscorer

Baseline:

python tools/ubuntu_train2.py rnn data/anssel/ubuntu/v2-vocab.pickle data/anssel/ubuntu/v2-trainset.pickle data/anssel/ubuntu/v2-valset.pickle "pact='tanh'" sdim=1/6 pdim=1/6 ptscorer=B.dot_ptscorer (implicit dropout=0 inp_e_dropout=0)
rnn-3b38f6cc1e91ec9e
Epoch 11/32
200064/200000 [==============================] - 767s - loss: 0.2440                                                          val mrr 0.729863
data/anssel/ubuntu/v2-valset.pickle MRR: 0.748693
data/anssel/ubuntu/v2-valset.pickle 2-R@1: 0.885123
data/anssel/ubuntu/v2-valset.pickle 10-R@1: 0.622342  10-R@2: 0.762321  10-R@5: 0.930930

This is pretty stable, about 1% extreme-to-extreme spread.
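As a rough check of the spread, the sketch below takes the final Val MRR of the no-dropout small-model runs that appear on this page (the baseline plus the d0 runs). Picking exactly these five runs is an assumption, and since it is not stated whether "1%" is absolute or relative, both are printed.

```python
# Val MRR of the no-dropout rnnm_sA / rn8_m runs as listed on this page;
# choosing exactly these runs is an assumption.
baseline_mrr = [0.748693, 0.754669, 0.750941, 0.752353, 0.750134]


def extreme_spread(values):
    """Absolute and relative (to the mean) distance between the best and worst run."""
    spread = max(values) - min(values)
    mean = sum(values) / len(values)
    return spread, spread / mean


absolute, relative = extreme_spread(baseline_mrr)
print("spread %.6f (%.2f%% of the mean)" % (absolute, 100 * relative))
```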

Experiments:

==> 10637656.arien.ics.muni.cz.0rnn_mA_d0i0 <==
200064/200000 [==============================] - 796s - loss: 0.1390                                                          val mrr 0.715701
Predict&Eval (best epoch)
Val MRR: 0.754669
Val 2-R@1: 0.892127
Val 10-R@1: 0.629090  10-R@2: 0.770757  10-R@5: 0.936094

==> 10637657.arien.ics.muni.cz.0rnn_mA_d0i0 <==
200064/200000 [==============================] - 779s - loss: 0.1405                                                          val mrr 0.708413
Predict&Eval (best epoch)
Val MRR: 0.750941
Val 2-R@1: 0.889519
Val 10-R@1: 0.625767  10-R@2: 0.762526  10-R@5: 0.935992

==> 10639375.arien.ics.muni.cz.0rn8_mAd0 <==
200064/200000 [==============================] - 774s - loss: 0.1414                                                          val mrr 0.706607
Predict&Eval (best epoch)
Val MRR: 0.752353
Val 2-R@1: 0.888753
Val 10-R@1: 0.627352  10-R@2: 0.765031  10-R@5: 0.936043

==> 10628079.arien.ics.muni.cz.rnnm_sAi001ad001a1a <==
200064/200000 [==============================] - 836s - loss: 0.3031                                                          val mrr 0.717387
Predict&Eval (best epoch)
Val MRR: 0.743645
Val 2-R@1: 0.881595
Val 10-R@1: 0.616207  10-R@2: 0.754959  10-R@5: 0.929755

==> 10628080.arien.ics.muni.cz.rnnm_sAi0015d001515 <==
200064/200000 [==============================] - 824s - loss: 0.3652                                                          val mrr 0.709904
Predict&Eval (best epoch)
Val MRR: 0.723464
Val 2-R@1: 0.865337
Val 10-R@1: 0.593405  10-R@2: 0.729908  10-R@5: 0.909509

==> 10628081.arien.ics.muni.cz.rnnm_sAi0015d001212 <==
200064/200000 [==============================] - 834s - loss: 0.4452                                                          val mrr 0.683098
Predict&Eval (best epoch)
Val MRR: 0.693378
Val 2-R@1: 0.842127
Val 10-R@1: 0.558538  10-R@2: 0.692689  10-R@5: 0.884458

==> 10628082.arien.ics.muni.cz.rnnm_sAi0013d001212 <==
200064/200000 [==============================] - 838s - loss: 0.4949                                                          val mrr 0.668988
Predict&Eval (best epoch)
Val MRR: 0.676498
Val 2-R@1: 0.820757
Val 10-R@1: 0.544581  10-R@2: 0.663855  10-R@5: 0.863088

==> 10628083.arien.ics.muni.cz.rnnm_sAi0013d002323 <==
200064/200000 [==============================] - 843s - loss: 0.5478                                                          val mrr 0.662266
Predict&Eval (best epoch)
Val MRR: 0.665297
Val 2-R@1: 0.813241
Val 10-R@1: 0.529499  10-R@2: 0.652965  10-R@5: 0.854397

==> 10629705.arien.ics.muni.cz.rnnm_sAi0015d001a1a <==
200064/200000 [==============================] - 817s - loss: 0.3824                                                          val mrr 0.723031
Predict&Eval (best epoch)
Val MRR: 0.727865
Val 2-R@1: 0.870194
Val 10-R@1: 0.599029  10-R@2: 0.734407  10-R@5: 0.912526

==> 10629717.arien.ics.muni.cz.rnnm_sAi0015d00151a <==
200064/200000 [==============================] - 824s - loss: 0.3391                                                          val mrr 0.710301
Predict&Eval (best epoch)
Val MRR: 0.731129
Val 2-R@1: 0.873620
Val 10-R@1: 0.602505  10-R@2: 0.737219  10-R@5: 0.917996

==> 10629718.arien.ics.muni.cz.rnnm_sAi0015d001a15 <==
200064/200000 [==============================] - 835s - loss: 0.3392                                                          val mrr 0.712082
Predict&Eval (best epoch)
Val MRR: 0.734751
Val 2-R@1: 0.876585
Val 10-R@1: 0.606595  10-R@2: 0.742587  10-R@5: 0.920450

==> 10629719.arien.ics.muni.cz.rnnm_sAi001ad001515 <==
200064/200000 [==============================] - 851s - loss: 0.2918                                                          val mrr 0.712226
Predict&Eval (best epoch)
Val MRR: 0.733117
Val 2-R@1: 0.877198
Val 10-R@1: 0.603272  10-R@2: 0.744172  10-R@5: 0.919734

==> 10629720.arien.ics.muni.cz.rnnm_sAi001ad001a00 <==
200064/200000 [==============================] - 847s - loss: 0.2853                                                          val mrr 0.720628
Predict&Eval (best epoch)
Val MRR: 0.743576
Val 2-R@1: 0.883436
Val 10-R@1: 0.617791  10-R@2: 0.751125  10-R@5: 0.928988

==> 10630520.arien.ics.muni.cz.rnnm_sAi0d0 <==
200064/200000 [==============================] - 777s - loss: 0.1393                                                          val mrr 0.708467
Predict&Eval (best epoch)
Val MRR: 0.750134
Val 2-R@1: 0.885838
Val 10-R@1: 0.625102  10-R@2: 0.762372  10-R@5: 0.931902

==> 10631132.arien.ics.muni.cz.rnnm_sAi1200d120000 <==
200064/200000 [==============================] - 818s - loss: 0.5020                                                          val mrr 0.649190
Predict&Eval (best epoch)
Val MRR: 0.651596
Val 2-R@1: 0.808691
Val 10-R@1: 0.511708  10-R@2: 0.633947  10-R@5: 0.848108

==> 10631133.arien.ics.muni.cz.rnnm_sAi1400d140000 <==
200064/200000 [==============================] - 871s - loss: 0.2410                                                          val mrr 0.717085
Predict&Eval (best epoch)
Val MRR: 0.742678
Val 2-R@1: 0.882515
Val 10-R@1: 0.614519  10-R@2: 0.754550  10-R@5: 0.929755

==> 10631134.arien.ics.muni.cz.rnnm_sAi3400d340000 <==
200064/200000 [==============================] - 854s - loss: 0.5765                                                          val mrr 0.492365
Predict&Eval (best epoch)
Val MRR: 0.531451
Val 2-R@1: 0.694836
Val 10-R@1: 0.377658  10-R@2: 0.492536  10-R@5: 0.712219

==> 10638093.arien.ics.muni.cz.0rnn_mA_d14i14 <==
200064/200000 [==============================] - 746s - loss: 0.3045                                                          val mrr 0.717984
Predict&Eval (best epoch)
Val MRR: 0.734787
Val 2-R@1: 0.878681
Val 10-R@1: 0.605266  10-R@2: 0.742945  10-R@5: 0.923262

==> 10638096.arien.ics.muni.cz.0rnn_mA_d12i14 <==
200064/200000 [==============================] - 768s - loss: 0.3087                                                          val mrr 0.682688
Predict&Eval (best epoch)
Val MRR: 0.705960
Val 2-R@1: 0.853016
Val 10-R@1: 0.573211  10-R@2: 0.705777  10-R@5: 0.896881

==> 10638097.arien.ics.muni.cz.0rnn_mA_d12i00 <==
200064/200000 [==============================] - 749s - loss: 0.3058                                                          val mrr 0.681843
Predict&Eval (best epoch)
Val MRR: 0.702096
Val 2-R@1: 0.848926
Val 10-R@1: 0.568763  10-R@2: 0.700613  10-R@5: 0.894581

Large model with spad=80

The rn8_t models use a larger RNN, which may be interesting for dropout experiments, but keep spad=80 so that training stays manageable:

 sdim=2 pdim=1 "pact='tanh'" ptscorer=B.dot_ptscorer

Baseline (repeated, some measurements cut off early by job system or divergence):

Epoch 17/32
200064/200000 [==============================] - 2448s - loss: 0.1755                                                         val mrr 0.738215
Predict&Eval (best epoch)
Val MRR: 0.766637
Val 2-R@1: 0.898160
Val 10-R@1: 0.644121  10-R@2: 0.786196  10-R@5: 0.943405

Epoch 4/32
200064/200000 [==============================] - 2423s - loss: 0.4320                                                         val mrr 0.761111

Epoch 9/32
200064/200000 [==============================] - 2389s - loss: 0.3566                                                         val mrr 0.773375

Epoch 10/32
200064/200000 [==============================] - 2467s - loss: 0.3566                                                         val mrr 0.778220

Epoch 4/32
200064/200000 [==============================] - 2436s - loss: 0.4363                                                         val mrr 0.760970

Epoch 4/32
200064/200000 [==============================] - 2421s - loss: 0.4337                                                         val mrr 0.758342

Experiments:

==> 10639342.arien.ics.muni.cz.0rn8_tAs2_d0i0014 <==
200064/200000 [==============================] - 2426s - loss: 0.1720                                                         val mrr 0.736546
Predict&Eval (best epoch)
Val MRR: 0.776231
Val 2-R@1: 0.902658
Val 10-R@1: 0.657720  10-R@2: 0.793712  10-R@5: 0.948108

==> 10639344.arien.ics.muni.cz.0rn8_tAs2_d140000i0 <==
200064/200000 [==============================] - 2444s - loss: 0.5671                                                         val mrr 0.668038
Predict&Eval (best epoch)
Val MRR: 0.739897
Val 2-R@1: 0.879499
Val 10-R@1: 0.613599  10-R@2: 0.747648  10-R@5: 0.923517

==> 10639345.arien.ics.muni.cz.0rn8_tAs2_d001400i0 <==
200064/200000 [==============================] - 2426s - loss: 7.9537                                                         val mrr 0.100000
Predict&Eval (best epoch)
Val MRR: 0.763297
Val 2-R@1: 0.894939
Val 10-R@1: 0.642178  10-R@2: 0.776278  10-R@5: 0.941411

==> 10639346.arien.ics.muni.cz.0rn8_tAs2_d000014i0 <==
200064/200000 [==============================] - 2452s - loss: 0.2336                                                         val mrr 0.721702
Predict&Eval (best epoch)
Val MRR: 0.767557
Val 2-R@1: 0.899438
Val 10-R@1: 0.646677  10-R@2: 0.782669  10-R@5: 0.944018

==> 10648978.arien.ics.muni.cz.0rn8_tAs2_d0i0012 <==
200064/200000 [==============================] - 2394s - loss: 0.1758                                                         val mrr 0.731808
Predict&Eval (best epoch)
Val MRR: 0.767523
Val 2-R@1: 0.898160
Val 10-R@1: 0.644990  10-R@2: 0.787986  10-R@5: 0.943814

==> 10648982.arien.ics.muni.cz.0rn8_tAs2_d001414i0014 <==
200064/200000 [==============================] - 2393s - loss: 0.1702                                                         val mrr 0.736441
Predict&Eval (best epoch)
Val MRR: 0.767466
Val 2-R@1: 0.896779
Val 10-R@1: 0.646830  10-R@2: 0.783742  10-R@5: 0.943814

==> 10648984.arien.ics.muni.cz.0rn8_tAs2_d001414i0 <==
200064/200000 [==============================] - 2387s - loss: 0.2995                                                         val mrr 0.752159
Predict&Eval (best epoch)
Val MRR: 0.768176
Val 2-R@1: 0.895961
Val 10-R@1: 0.647137  10-R@2: 0.785123  10-R@5: 0.943149

==> 10648988.arien.ics.muni.cz.0rn8_tAs2_d001214i0 <==
200064/200000 [==============================] - 2342s - loss: 0.1617                                                         val mrr 0.738756
Predict&Eval (best epoch)
Val MRR: 0.775496
Val 2-R@1: 0.901585
Val 10-R@1: 0.656646  10-R@2: 0.795399  10-R@5: 0.947648

==> 10648990.arien.ics.muni.cz.0rn8_tAs2_d001412i0 <==
200064/200000 [==============================] - 2402s - loss: 0.1686                                                         val mrr 0.737475
Predict&Eval (best epoch)
Val MRR: 0.770147
Val 2-R@1: 0.902556
Val 10-R@1: 0.650613  10-R@2: 0.785838  10-R@5: 0.944683

With the Keras RNN input dropout bugfix:

10658664.arien.ics.muni.cz.1rn8_tAs2_d001400i0
Epoch 9/32
200064/200000 [==============================] - 2381s - loss: 0.4248                                                         val mrr 0.768909
(diverges later)

10658665.arien.ics.muni.cz.1rn8_tAs2_d001412i0
Epoch 20/32
200064/200000 [==============================] - 2603s - loss: 0.3610                                                         val mrr 0.737464
Predict&Eval (best epoch)
Val MRR: 0.745073
Val 2-R@1: 0.883027
Val 10-R@1: 0.617076  10-R@2: 0.756851  10-R@5: 0.932720

So the bugfix doesn't help with anything either.
