This repository has been archived by the owner on Dec 17, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
heteroscedastic.py
1983 lines (1680 loc) · 76.9 KB
/
heteroscedastic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2021 The Edward2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library of methods to compute heteroscedastic classification predictions."""
import random
import tensorflow as tf
import tensorflow_probability as tfp
# Floor added to the softplus-activated scale outputs of the layers in this
# module, so noise scales stay strictly positive even if softplus underflows.
MIN_SCALE_MONTE_CARLO = 0.001
class MCSoftmaxOutputLayerBase(tf.keras.layers.Layer):
  """Base class for MC heteroscedastic output layers.

  Subclasses supply `_compute_loc_param` and `_compute_scale_param`. This class
  perturbs the location parameters with scaled location-scale noise, applies a
  tempered softmax (or a sigmoid when `num_classes == 2`) to every noisy
  sample and averages over the Monte-Carlo samples to estimate the predictive
  distribution.

  Collier, M., Mustafa, B., Kokiopoulou, E., Jenatton, R., & Berent, J. (2020).
  A Simple Probabilistic Method for Deep Classification under Input-Dependent
  Label Noise. arXiv preprint arXiv:2003.06778.
  """

  def __init__(
      self,
      num_classes,
      logit_noise=tfp.distributions.Normal,
      temperature=1.0,
      train_mc_samples=1000,
      test_mc_samples=1000,
      compute_pred_variance=False,
      share_samples_across_batch=False,
      logits_only=False,
      eps=1e-7,
      name="MCSoftmaxOutputLayerBase",
  ):
    """Creates an instance of MCSoftmaxOutputLayerBase.

    Args:
      num_classes: Integer. Number of classes for classification task.
      logit_noise: tfp.distributions class (not an instance). Must be a
        location-scale distribution. Valid values: tfp.distributions.Normal,
        tfp.distributions.Logistic, tfp.distributions.Gumbel.
      temperature: Float or scalar `Tensor` representing the softmax
        temperature.
      train_mc_samples: The number of Monte-Carlo samples used to estimate the
        predictive distribution during training.
      test_mc_samples: The number of Monte-Carlo samples used to estimate the
        predictive distribution during testing/inference.
      compute_pred_variance: Boolean. Whether to estimate the predictive
        variance. If False the __call__ method will output None for the
        predictive_variance tensor.
      share_samples_across_batch: Boolean. If True, the latent noise samples
        are shared across batch elements. If encountering XLA compilation
        errors due to dynamic shape inference, setting = True may solve.
      logits_only: Boolean. If True, only return the logits from the __call__
        method. Useful when a single output Tensor is required e.g.
        tf.keras.Sequential models require a single output Tensor.
      eps: Float. Clip probabilities into [eps, 1.0] before applying log, and
        additionally into [eps, 1.0 - eps] before the inverse sigmoid used in
        the binary case.
      name: String. The name of the layer used for name scoping.

    Raises:
      ValueError: if logit_noise is not one of tfp.distributions.Normal,
        tfp.distributions.Logistic or tfp.distributions.Gumbel.
    """
    if logit_noise not in (
        tfp.distributions.Normal,
        tfp.distributions.Logistic,
        tfp.distributions.Gumbel,
    ):
      raise ValueError("logit_noise must be Normal, Logistic or Gumbel")

    super(MCSoftmaxOutputLayerBase, self).__init__(name=name)

    self._num_classes = num_classes
    self._logit_noise = logit_noise
    self._temperature = temperature
    self._train_mc_samples = train_mc_samples
    self._test_mc_samples = test_mc_samples
    self._compute_pred_variance = compute_pred_variance
    self._share_samples_across_batch = share_samples_across_batch
    self._logits_only = logits_only
    self._eps = eps
    self._name = name

  def _compute_noise_samples(self, scale, num_samples, seed):
    """Utility function to compute the samples of the logit noise.

    Args:
      scale: Tensor of shape
        [batch_size, 1 if num_classes == 2 else num_classes].
        Scale parameters of the distributions to be sampled.
      num_samples: Integer. Number of Monte-Carlo samples to take.
      seed: Python integer for seeding the random number generator. When not
        None it is also installed as the global TensorFlow seed.

    Returns:
      Tensor. Logit noise samples of shape [batch_size, num_samples,
      1 if num_classes == 2 else num_classes].
    """
    # Binary problems use a single sigmoid logit, so the noise needs one
    # column (matching `scale`), not `num_classes` columns.
    num_outputs = 1 if self._num_classes == 2 else self._num_classes

    if self._share_samples_across_batch:
      num_noise_samples = 1
    else:
      num_noise_samples = tf.shape(scale)[0]

    dist = self._logit_noise(
        loc=tf.zeros([num_noise_samples, num_outputs], dtype=scale.dtype),
        scale=tf.ones([num_noise_samples, num_outputs], dtype=scale.dtype),
    )

    # Only touch the global seed when a seed was requested; calling
    # tf.random.set_seed(None) would discard a global seed set by the user.
    if seed is not None:
      tf.random.set_seed(seed)
    noise_samples = dist.sample(num_samples, seed=seed)

    # dist.sample(num_samples) returns a Tensor of shape
    # [num_samples, num_noise_samples, d]; move the sample axis to position 1.
    # With shared samples the leading axis has size 1 and broadcasts against
    # the batch axis of `scale`.
    noise_samples = tf.transpose(noise_samples, [1, 0, 2])
    return noise_samples * tf.expand_dims(scale, 1)

  def _compute_mc_samples(self, locs, scale, num_samples, seed):
    """Utility function to compute Monte-Carlo samples (using softmax).

    Args:
      locs: Tensor of shape [batch_size, 1 if num_classes == 2 else
        num_classes]. Location parameters of the distributions to be sampled.
      scale: Scale parameters of the distributions to be sampled, in the form
        expected by `_compute_noise_samples` (for this base class a Tensor
        shaped like `locs`).
      num_samples: Integer. Number of Monte-Carlo samples to take.
      seed: Python integer for seeding the random number generator.

    Returns:
      Tensor of shape [batch_size, num_samples,
      1 if num_classes == 2 else num_classes]. All of the MC samples: sigmoid
      probabilities when num_classes == 2, softmax probabilities otherwise.
    """
    # [batch_size, 1, d] so the locations broadcast over the sample axis.
    locs = tf.expand_dims(locs, axis=1)
    noise_samples = self._compute_noise_samples(scale, num_samples, seed)
    latents = locs + noise_samples
    if self._num_classes == 2:
      return tf.math.sigmoid(latents / self._temperature)
    else:
      return tf.nn.softmax(latents / self._temperature)

  def _compute_predictive_mean(self, locs, scale, total_mc_samples, seed):
    """Utility function to compute the estimated predictive distribution.

    Args:
      locs: Tensor of shape [batch_size, 1 if num_classes == 2 else
        num_classes]. Location parameters of the distributions to be sampled.
      scale: Scale parameters of the distributions to be sampled (see
        `_compute_mc_samples`).
      total_mc_samples: Integer. Number of Monte-Carlo samples to take.
      seed: Python integer for seeding the random number generator.

    Returns:
      Tensor of shape [batch_size, 1 if num_classes == 2 else num_classes]
      - the mean of the MC samples.
    """
    # `__call__` already draws a seed when the variance is requested and
    # passes it here and to `_compute_predictive_variance`; this fallback is
    # kept for callers that invoke this method directly.
    if self._compute_pred_variance and seed is None:
      seed = random.randrange(2 ** 63 - 1)
    samples = self._compute_mc_samples(locs, scale, total_mc_samples, seed)
    return tf.reduce_mean(samples, axis=1)

  def _compute_predictive_variance(self, mean, locs, scale, seed, num_samples):
    """Utility function to compute the per class predictive variance.

    The MC samples are drawn again using `seed`; `__call__` passes the same
    seed that was used for the predictive mean so the two estimates are meant
    to share their Monte-Carlo samples.

    Args:
      mean: Tensor of shape [batch_size, 1 if num_classes == 2 else
        num_classes]. Estimated predictive distribution.
      locs: Tensor of shape [batch_size, 1 if num_classes == 2 else
        num_classes]. Location parameters of the distributions to be sampled.
      scale: Scale parameters of the distributions to be sampled (see
        `_compute_mc_samples`).
      seed: Python integer for seeding the random number generator.
      num_samples: Integer. Number of Monte-Carlo samples to take.

    Returns:
      Tensor of shape [batch_size, 1 if num_classes == 2 else num_classes].
      Estimated predictive variance.
    """
    mean = tf.expand_dims(mean, axis=1)
    mc_samples = self._compute_mc_samples(locs, scale, num_samples, seed)
    return tf.reduce_mean((mc_samples - mean) ** 2, axis=1)

  def _compute_loc_param(self, inputs):
    """Computes location parameter of the "logits distribution".

    Must be implemented by subclasses.

    Args:
      inputs: Tensor. The input to the heteroscedastic output layer.

    Returns:
      Tensor of shape [batch_size, num_classes].

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError(
        "Subclasses must implement _compute_loc_param.")

  def _compute_scale_param(self, inputs):
    """Computes scale parameter of the "logits distribution".

    Must be implemented by subclasses.

    Args:
      inputs: Tensor. The input to the heteroscedastic output layer.

    Returns:
      Tensor of shape [batch_size, num_classes].

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError(
        "Subclasses must implement _compute_scale_param.")

  def __call__(self, inputs, training=True, seed=None):
    """Computes predictive and log predictive distribution.

    Uses Monte Carlo estimate of softmax approximation to heteroscedastic model
    to compute predictive distribution. O(mc_samples * num_classes).

    Args:
      inputs: Tensor. The input to the heteroscedastic output layer.
      training: Boolean. Whether we are training or not; selects
        train_mc_samples or test_mc_samples.
      seed: Python integer for seeding the random number generator. If None
        and compute_pred_variance is True, a seed is drawn from Python's
        `random` module so the mean and variance use the same samples.

    Returns:
      Tensor logits if logits_only = True. Otherwise,
      tuple of (logits, log_probs, probs, predictive_variance). For multi-class
      classification i.e. num_classes > 2 logits = log_probs and logits can be
      used with the standard tf.nn.sparse_softmax_cross_entropy_with_logits loss
      function. For binary classification i.e. num_classes = 2, logits
      represents the argument to a sigmoid function that would yield probs
      (logits = inverse_sigmoid(probs)), so logits can be used with the
      tf.nn.sigmoid_cross_entropy_with_logits loss function.

    Raises:
      ValueError: if seed is provided but model is running in graph mode.
    """
    # Seed shouldn't be provided in graph mode.
    if seed is not None and not tf.executing_eagerly():
      raise ValueError(
          "Seed should not be provided when running in graph "
          "mode, but %s was provided." % seed
      )

    with tf.name_scope(self._name):
      # pylint: disable=assignment-from-none,assignment-from-no-return
      locs = self._compute_loc_param(inputs)
      scale = self._compute_scale_param(inputs)
      # pylint: enable=assignment-from-none,assignment-from-no-return

      if training:
        total_mc_samples = self._train_mc_samples
      else:
        total_mc_samples = self._test_mc_samples

      if self._compute_pred_variance and seed is None:
        # Draw one seed and hand it to both the mean and the variance
        # computations so they are estimated from the same MC samples.
        seed = random.randrange(2 ** 63 - 1)

      probs_mean = self._compute_predictive_mean(
          locs, scale, total_mc_samples, seed
      )

      pred_variance = None
      if self._compute_pred_variance:
        pred_variance = self._compute_predictive_variance(
            probs_mean, locs, scale, seed, total_mc_samples
        )

      probs_mean = tf.clip_by_value(probs_mean, self._eps, 1.0)
      log_probs = tf.math.log(probs_mean)

      if self._num_classes == 2:
        # inverse sigmoid: log(p) - log(1 - p), with p kept away from 1.
        probs_mean = tf.clip_by_value(probs_mean, self._eps, 1.0 - self._eps)
        logits = log_probs - tf.math.log(1.0 - probs_mean)
      else:
        logits = log_probs

      if self._logits_only:
        return logits

      return logits, log_probs, probs_mean, pred_variance

  def get_config(self):
    """Returns the layer configuration, including `eps`."""
    config = {
        "num_classes": self._num_classes,
        "logit_noise": self._logit_noise,
        "temperature": self._temperature,
        "train_mc_samples": self._train_mc_samples,
        "test_mc_samples": self._test_mc_samples,
        "compute_pred_variance": self._compute_pred_variance,
        "share_samples_across_batch": self._share_samples_across_batch,
        "logits_only": self._logits_only,
        "eps": self._eps,
        "name": self._name,
    }
    new_config = super().get_config()
    new_config.update(config)
    return new_config
class MCSoftmaxDense(MCSoftmaxOutputLayerBase):
  """Monte Carlo estimation of softmax approx to heteroscedastic predictions.

  The location and the (softplus-activated) scale of the logit noise come from
  two separate Dense heads applied to the same input.
  """

  def __init__(
      self,
      num_classes,
      logit_noise=tfp.distributions.Normal,
      temperature=1.0,
      train_mc_samples=1000,
      test_mc_samples=1000,
      compute_pred_variance=False,
      share_samples_across_batch=False,
      logits_only=False,
      eps=1e-7,
      dtype=None,
      kernel_regularizer=None,
      bias_regularizer=None,
      name="MCSoftmaxDense",
  ):
    """Builds an MCSoftmaxDense layer.

    Intended as a drop-in heteroscedastic replacement for a
    `tf.keras.layers.Dense` output layer: replace

    ```python
    logits = tf.keras.layers.Dense(...)(x)
    ```

    with

    ```python
    logits = MCSoftmaxDense(...)(x)[0]
    ```

    Args:
      num_classes: Integer. Number of classes for classification task; must
        be at least 2.
      logit_noise: tfp.distributions class. Must be a location-scale
        distribution. Valid values: tfp.distributions.Normal,
        tfp.distributions.Logistic, tfp.distributions.Gumbel.
      temperature: Float or scalar `Tensor` representing the softmax
        temperature.
      train_mc_samples: Number of Monte-Carlo samples used to estimate the
        predictive distribution during training.
      test_mc_samples: Number of Monte-Carlo samples used to estimate the
        predictive distribution during testing/inference.
      compute_pred_variance: Boolean. Whether to estimate the predictive
        variance. If False the __call__ method outputs None for the
        predictive_variance tensor.
      share_samples_across_batch: Boolean. If True, the latent noise samples
        are shared across batch elements. Setting it to True may resolve XLA
        compilation errors caused by dynamic shape inference.
      logits_only: Boolean. If True, __call__ returns only the logits. Set
        True to serialize tf.keras.Sequential models.
      eps: Float. Probabilities are clipped into [eps, 1.0] before the log.
      dtype: Tensorflow dtype of the output Tensor and of the weights of the
        inner Dense layers.
      kernel_regularizer: Regularizer applied to the `kernel` weight matrices.
      bias_regularizer: Regularizer applied to the bias vectors.
      name: String. The name of the layer used for name scoping.

    Raises:
      ValueError: if logit_noise is not one of tfp.distributions.Normal,
        tfp.distributions.Logistic or tfp.distributions.Gumbel.
    """
    assert num_classes >= 2

    super().__init__(
        num_classes,
        logit_noise=logit_noise,
        temperature=temperature,
        train_mc_samples=train_mc_samples,
        test_mc_samples=test_mc_samples,
        compute_pred_variance=compute_pred_variance,
        share_samples_across_batch=share_samples_across_batch,
        logits_only=logits_only,
        eps=eps,
        name=name,
    )

    # A binary problem is modelled with a single (sigmoid) logit.
    num_outputs = 1 if num_classes == 2 else num_classes
    shared_dense_kwargs = dict(
        dtype=dtype,
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer,
    )

    self._loc_layer = tf.keras.layers.Dense(
        num_outputs, activation=None, name="loc_layer", **shared_dense_kwargs
    )
    self._scale_layer = tf.keras.layers.Dense(
        num_outputs,
        activation=tf.math.softplus,
        name="scale_layer",
        **shared_dense_kwargs
    )

  def _compute_loc_param(self, inputs):
    """Returns the location parameter of the "logits distribution".

    Args:
      inputs: Tensor. The input to the heteroscedastic output layer.

    Returns:
      Output of the location Dense head,
      [batch_size, 1 if num_classes == 2 else num_classes].
    """
    loc = self._loc_layer(inputs)
    return loc

  def _compute_scale_param(self, inputs):
    """Returns the scale parameter of the "logits distribution".

    Args:
      inputs: Tensor. The input to the heteroscedastic output layer.

    Returns:
      Softplus scale head output plus MIN_SCALE_MONTE_CARLO,
      [batch_size, 1 if num_classes == 2 else num_classes].
    """
    raw_scale = self._scale_layer(inputs)
    return raw_scale + MIN_SCALE_MONTE_CARLO

  def get_config(self):
    """Returns the base config extended with the serialized Dense heads."""
    layer_configs = {
        "loc_layer": tf.keras.layers.serialize(self._loc_layer),
        "scale_layer": tf.keras.layers.serialize(self._scale_layer),
    }
    return {**super().get_config(), **layer_configs}
class MCSoftmaxDenseFA(MCSoftmaxOutputLayerBase):
  """Softmax and factor analysis approx to heteroscedastic predictions."""

  def __init__(
      self,
      num_classes,
      num_factors,
      temperature=1.0,
      parameter_efficient=False,
      train_mc_samples=1000,
      test_mc_samples=1000,
      compute_pred_variance=False,
      share_samples_across_batch=False,
      logits_only=False,
      eps=1e-7,
      dtype=None,
      kernel_regularizer=None,
      bias_regularizer=None,
      name="MCSoftmaxDenseFA",
  ):
    """Creates an instance of MCSoftmaxDenseFA.

    If we assume:

    ```
    u ~ N(mu(x), sigma(x))
    y = softmax(u / temperature)
    ```

    we can do a low rank approximation of the full rank matrix sigma(x) as:

    ```
    eps_R ~ N(0, I_R), eps_K ~ N(0, I_K)
    u = mu(x) + matmul(V(x), eps_R) + d(x) * eps_K
    ```

    where V(x) is a matrix of dimension [num_classes, R=num_factors]
    and d(x) is a vector of dimension [num_classes, 1].
    num_factors << num_classes => approx to sampling ~ N(mu(x), sigma(x)).

    This is a MC softmax heteroscedastic drop in replacement for a
    tf.keras.layers.Dense output layer. e.g. simply change:

    ```python
    logits = tf.keras.layers.Dense(...)(x)
    ```

    to

    ```python
    logits = MCSoftmaxDenseFA(...)(x)[0]
    ```

    Args:
      num_classes: Integer. Number of classes for classification task; must
        be greater than 2.
      num_factors: Integer. Number of factors to use in approximation to full
        rank covariance matrix; must not exceed num_classes.
      temperature: Float or scalar `Tensor` representing the softmax
        temperature.
      parameter_efficient: Boolean. Whether to use the parameter efficient
        version of the method. If True then samples from the latent
        distribution are generated as:
        mu(x) + v(x) * matmul(V, eps_R) + diag(d(x)) * eps_K,
        where eps_R ~ N(0, I_R), eps_K ~ N(0, I_K). If False then latent
        samples are generated as:
        mu(x) + matmul(V(x), eps_R) + diag(d(x)) * eps_K.
        Computing V(x) as function of x increases the number of parameters
        introduced by the method.
      train_mc_samples: The number of Monte-Carlo samples used to estimate the
        predictive distribution during training.
      test_mc_samples: The number of Monte-Carlo samples used to estimate the
        predictive distribution during testing/inference.
      compute_pred_variance: Boolean. Whether to estimate the predictive
        variance. If False the __call__ method will output None for the
        predictive_variance tensor.
      share_samples_across_batch: Boolean. If True, the latent noise samples
        are shared across batch elements. If encountering XLA compilation
        errors due to dynamic shape inference, setting = True may solve.
      logits_only: Boolean. If True, only return the logits from the __call__
        method. Set True to serialize tf.keras.Sequential models.
      eps: Float. Clip probabilities into [eps, 1.0] before applying log.
      dtype: Tensorflow dtype. The dtype of output Tensor and weights
        associated with the layer.
      kernel_regularizer: Regularizer function applied to the `kernel` weights
        matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      name: String. The name of the layer used for name scoping.
    """
    # no need to model correlations between classes in binary case
    assert num_classes > 2
    assert num_factors <= num_classes

    super(MCSoftmaxDenseFA, self).__init__(
        num_classes,
        logit_noise=tfp.distributions.Normal,
        temperature=temperature,
        train_mc_samples=train_mc_samples,
        test_mc_samples=test_mc_samples,
        compute_pred_variance=compute_pred_variance,
        share_samples_across_batch=share_samples_across_batch,
        logits_only=logits_only,
        eps=eps,
        name=name,
    )

    self._num_factors = num_factors
    self._parameter_efficient = parameter_efficient

    if parameter_efficient:
      self._scale_layer_homoscedastic = tf.keras.layers.Dense(
          num_classes,
          name=name + "_scale_layer_homoscedastic",
          dtype=dtype,
          kernel_regularizer=kernel_regularizer,
          bias_regularizer=bias_regularizer,
      )
      self._scale_layer_heteroscedastic = tf.keras.layers.Dense(
          num_classes,
          name=name + "_scale_layer_heteroscedastic",
          dtype=dtype,
          kernel_regularizer=kernel_regularizer,
          bias_regularizer=bias_regularizer,
      )
    else:
      self._scale_layer = tf.keras.layers.Dense(
          num_classes * num_factors,
          name=name + "_scale_layer",
          dtype=dtype,
          kernel_regularizer=kernel_regularizer,
          bias_regularizer=bias_regularizer,
      )

    self._loc_layer = tf.keras.layers.Dense(
        num_classes,
        name=name + "_loc_layer",
        dtype=dtype,
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer,
    )
    self._diag_layer = tf.keras.layers.Dense(
        num_classes,
        activation=tf.math.softplus,
        name=name + "_diag_layer",
        dtype=dtype,
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer,
    )

  def _compute_loc_param(self, inputs):
    """Computes location parameter of the "logits distribution".

    Args:
      inputs: Tensor. The input to the heteroscedastic output layer.

    Returns:
      Tensor of shape [batch_size, num_classes].
    """
    return self._loc_layer(inputs)

  def _compute_scale_param(self, inputs):
    """Computes scale parameter of the "logits distribution".

    Args:
      inputs: Tensor. The input to the heteroscedastic output layer.

    Returns:
      Tuple `(factor_part, diag_scale)`. `diag_scale` has shape
      [batch_size, num_classes]. `factor_part` is `inputs` itself when
      parameter_efficient is True, otherwise the flattened factor loadings of
      shape [batch_size, num_classes * num_factors].
    """
    if self._parameter_efficient:
      return (inputs, self._diag_layer(inputs) + MIN_SCALE_MONTE_CARLO)
    else:
      return (
          self._scale_layer(inputs),
          self._diag_layer(inputs) + MIN_SCALE_MONTE_CARLO,
      )

  def _compute_diagonal_noise_samples(self, diag_scale, num_samples, seed):
    """Compute samples of the diagonal elements logit noise.

    Args:
      diag_scale: `Tensor` of shape [batch_size, num_classes]. Diagonal
        elements of scale parameters of the distribution to be sampled.
      num_samples: Integer. Number of Monte-Carlo samples to take.
      seed: Python integer for seeding the random number generator. When not
        None it is also installed as the global TensorFlow seed.

    Returns:
      `Tensor`. Logit noise samples of shape
      [batch_size, num_samples, num_classes].
    """
    if self._share_samples_across_batch:
      num_noise_samples = 1
    else:
      num_noise_samples = tf.shape(diag_scale)[0]

    dist = tfp.distributions.Normal(
        loc=tf.zeros(
            [num_noise_samples, self._num_classes], dtype=diag_scale.dtype
        ),
        scale=tf.ones(
            [num_noise_samples, self._num_classes], dtype=diag_scale.dtype
        ),
    )

    # Avoid tf.random.set_seed(None): it would discard a user's global seed.
    if seed is not None:
      tf.random.set_seed(seed)
    diag_noise_samples = dist.sample(num_samples, seed=seed)

    # dist.sample(num_samples) returns Tensor of shape
    # [num_samples, num_noise_samples, d], here we reshape to
    # [num_noise_samples, num_samples, d]; a leading axis of size 1 (shared
    # samples) broadcasts against the batch axis of diag_scale below.
    diag_noise_samples = tf.transpose(diag_noise_samples, [1, 0, 2])

    return diag_noise_samples * tf.expand_dims(diag_scale, 1)

  def _compute_standard_normal_samples(self, factor_loadings, num_samples, seed):
    """Utility function to compute samples from a standard normal distribution.

    Args:
      factor_loadings: `Tensor` whose leading axis is the batch axis (the
        factor loadings, or the raw inputs when parameter_efficient is True).
        Only its batch size and dtype are used here.
      num_samples: Integer. Number of Monte-Carlo samples to take.
      seed: Python integer for seeding the random number generator. When not
        None it is also installed as the global TensorFlow seed.

    Returns:
      `Tensor`. Samples of shape: [batch_size, num_samples, num_factors].
    """
    if self._share_samples_across_batch:
      num_noise_samples = 1
    else:
      num_noise_samples = tf.shape(factor_loadings)[0]

    dist = tfp.distributions.Normal(
        loc=tf.zeros(
            [num_noise_samples, self._num_factors], dtype=factor_loadings.dtype
        ),
        scale=tf.ones(
            [num_noise_samples, self._num_factors], dtype=factor_loadings.dtype
        ),
    )

    # Avoid tf.random.set_seed(None): it would discard a user's global seed.
    if seed is not None:
      tf.random.set_seed(seed)
    standard_normal_samples = dist.sample(num_samples, seed=seed)

    # dist.sample(num_samples) returns Tensor of shape
    # [num_samples, num_noise_samples, d], here we reshape to
    # [num_noise_samples, num_samples, d]
    standard_normal_samples = tf.transpose(standard_normal_samples, [1, 0, 2])

    if self._share_samples_across_batch:
      # Replicate the single shared draw for every batch element.
      standard_normal_samples = tf.tile(
          standard_normal_samples, [tf.shape(factor_loadings)[0], 1, 1]
      )

    return standard_normal_samples

  def _compute_noise_samples(self, scale, num_samples, seed):
    """Utility function to compute the samples of the logit noise.

    Args:
      scale: Tuple `(factor_part, diag_scale)` as returned by
        `_compute_scale_param`.
      num_samples: Integer. Number of Monte-Carlo samples to take.
      seed: Python integer for seeding the random number generator.

    Returns:
      `Tensor`. Logit noise samples of shape
      [batch_size, num_samples, num_classes].
    """
    factor_loadings, diag_scale = scale

    # Compute the diagonal noise
    diag_noise_samples = self._compute_diagonal_noise_samples(
        diag_scale, num_samples, seed
    )

    # Each sampling helper re-seeds with the seed it is given. Reusing `seed`
    # here would restart the same random stream used for the diagonal noise
    # and correlate the two noise terms, so use a distinct but deterministic
    # seed (kept in the non-negative int64 range). Repeated calls with the
    # same `seed` (mean and variance) still reproduce identical samples.
    factor_seed = None if seed is None else (seed + 1) % (2 ** 63)

    # Now compute the factors
    standard_normal_samples = self._compute_standard_normal_samples(
        factor_loadings, num_samples, factor_seed
    )

    if self._parameter_efficient:
      res = self._scale_layer_homoscedastic(standard_normal_samples)
      res *= tf.expand_dims(self._scale_layer_heteroscedastic(factor_loadings), 1)
    else:
      # reshape scale vector into factor loadings matrix
      # [batch_size, num_classes, num_factors]
      factor_loadings = tf.reshape(
          factor_loadings, [-1, self._num_classes, self._num_factors]
      )

      # transform standard normal into ~ full rank covariance Gaussian samples
      # of shape [batch_size, num_samples, num_classes]
      res = tf.einsum("ijk,iak->iaj", factor_loadings, standard_normal_samples)
    return res + diag_noise_samples

  def get_config(self):
    """Returns the layer config, including the sub-layer configurations."""
    config = {
        "loc_layer": self._loc_layer.get_config(),
        "diag_layer": self._diag_layer.get_config(),
        "num_factors": self._num_factors,
        "parameter_efficient": self._parameter_efficient,
    }

    if self._parameter_efficient:
      config["scale_layer_homoscedastic"] = tf.keras.layers.serialize(
          self._scale_layer_homoscedastic
      )
      config["scale_layer_heteroscedastic"] = tf.keras.layers.serialize(
          self._scale_layer_heteroscedastic
      )
    else:
      config["scale_layer"] = tf.keras.layers.serialize(self._scale_layer)

    new_config = super().get_config()
    new_config.update(config)
    return new_config
class MultiHeadMCSoftmaxDenseFA(MCSoftmaxOutputLayerBase):
"""Softmax and factor analysis approx to heteroscedastic predictions.
Multi Head variation where the output is composed by multiple (ensemble size)
output predictions, with a shared latent space between ensembles.
"""
def __init__(
    self,
    num_classes,
    num_factors,
    ensemble_size,
    temperature=1.0,
    parameter_efficient=False,
    train_mc_samples=1000,
    test_mc_samples=1000,
    compute_pred_variance=False,
    share_samples_across_batch=False,
    logits_only=False,
    eps=1e-7,
    dtype=None,
    kernel_regularizer=None,
    bias_regularizer=None,
    name="MultiHeadMCSoftmaxDenseFA",
):
  """Builds a MultiHeadMCSoftmaxDenseFA layer.

  Model:

  ```
  u ~ N(mu(x), sigma(x))
  where x is [x1, x2, ... xn], with n = ensemble_size
  y = [softmax(u_i / temperature)
       for u_i in u.reshape(ensemble_size, num_classes)]
  ```

  The full rank sigma(x) is replaced by a low rank plus diagonal
  approximation:

  ```
  eps_R ~ N(0, I_R), eps_K ~ N(0, I_K)
  u = mu(x) + matmul(V(x), eps_R) + d(x) * eps_K
  ```

  with V(x) of dimension [num_classes * ensemble_size, R=num_factors] and
  d(x) of dimension [num_classes * ensemble_size, 1]. When
  num_factors << num_classes * ensemble_size this approximates sampling
  from N(mu(x), sigma(x)).

  Intended as a drop-in heteroscedastic replacement for a
  `tf.keras.layers.Dense` output layer: replace

  ```python
  logits = tf.keras.layers.Dense(...)(x)
  ```

  with

  ```python
  logits = MultiHeadMCSoftmaxDenseFA(...)(x)[0]
  ```

  Args:
    num_classes: Integer. Number of classes for classification task; must be
      greater than 2.
    num_factors: Integer. Number of factors in the low rank approximation of
      the covariance matrix; must not exceed num_classes.
    ensemble_size: Integer. Size of ensemble.
    temperature: Float or scalar `Tensor` representing the softmax
      temperature.
    parameter_efficient: Boolean. If True, latent samples are generated as
      mu(x) + v(x) * matmul(V, eps_R) + diag(d(x)) * eps_K; if False, as
      mu(x) + matmul(V(x), eps_R) + diag(d(x)) * eps_K, where
      eps_R ~ N(0, I_R) and eps_K ~ N(0, I_K). Computing V(x) as a function
      of x increases the number of parameters introduced by the method.
    train_mc_samples: Number of Monte-Carlo samples used to estimate the
      predictive distribution during training.
    test_mc_samples: Number of Monte-Carlo samples used to estimate the
      predictive distribution during testing/inference.
    compute_pred_variance: Boolean. Whether to estimate the predictive
      variance. If False the __call__ method outputs None for the
      predictive_variance tensor.
    share_samples_across_batch: Boolean. If True, the latent noise samples are
      shared across batch elements. Setting it to True may resolve XLA
      compilation errors caused by dynamic shape inference.
    logits_only: Boolean. If True, __call__ returns only the logits. Set True
      to serialize tf.keras.Sequential models.
    eps: Float. Probabilities are clipped into [eps, 1.0] before the log.
    dtype: Tensorflow dtype of the output Tensor and of the weights of the
      inner Dense layers.
    kernel_regularizer: Regularizer applied to the `kernel` weight matrices.
    bias_regularizer: Regularizer applied to the bias vectors.
    name: String. The name of the layer used for name scoping.
  """
  # Class correlations are not modelled in the binary case.
  assert num_classes > 2
  assert num_factors <= num_classes

  super(MultiHeadMCSoftmaxDenseFA, self).__init__(
      num_classes,
      logit_noise=tfp.distributions.Normal,
      temperature=temperature,
      train_mc_samples=train_mc_samples,
      test_mc_samples=test_mc_samples,
      compute_pred_variance=compute_pred_variance,
      share_samples_across_batch=share_samples_across_batch,
      logits_only=logits_only,
      eps=eps,
      name=name,
  )

  self._num_factors = num_factors
  self._parameter_efficient = parameter_efficient
  self._ensemble_size = ensemble_size

  # Total logits across the ensemble: num_classes for each member.
  num_logits = num_classes * ensemble_size
  dense_kwargs = dict(
      dtype=dtype,
      kernel_regularizer=kernel_regularizer,
      bias_regularizer=bias_regularizer,
  )

  if not parameter_efficient:
    self._scale_layer = tf.keras.layers.Dense(
        num_logits * num_factors, name="scale_layer", **dense_kwargs
    )
  else:
    self._scale_layer_homoscedastic = tf.keras.layers.Dense(
        num_logits, name="scale_layer_homoscedastic", **dense_kwargs
    )
    self._scale_layer_heteroscedastic = tf.keras.layers.Dense(
        num_logits, name="scale_layer_heteroscedastic", **dense_kwargs
    )

  self._loc_layer = tf.keras.layers.Dense(
      num_logits, name="loc_layer", **dense_kwargs
  )
  self._diag_layer = tf.keras.layers.Dense(
      num_logits,
      activation=tf.math.softplus,
      name="diag_layer",
      **dense_kwargs
  )
def _compute_mc_samples(self, locs, scale, num_samples, seed):
  """Utility function to compute Monte-Carlo samples (using softmax).

  Args:
    locs: Tensor of shape [batch_size, num_classes * ensemble_size].
      Location parameters of the distributions to be sampled.
    scale: Scale parameters of the distributions to be sampled, in the form
      produced by `_compute_scale_param`.
    num_samples: Integer. Number of Monte-Carlo samples to take.
    seed: Python integer for seeding the random number generator.

  Returns:
    Tensor of shape [batch_size, num_samples, ensemble_size, num_classes].
    Per-ensemble-member probabilities for every MC sample.
  """
  # [batch_size, 1, num_classes * ensemble_size]: broadcasts over samples.
  locs = tf.expand_dims(locs, axis=1)
  noise_samples = self._compute_noise_samples(scale, num_samples, seed)
  latents = locs + noise_samples

  # Split the flat logits into one block of num_classes per ensemble member.
  # tf.reshape replaces a tf.keras.layers.Reshape that was constructed anew on
  # every call: no throwaway layer per call, and the input dtype is kept
  # (Keras layers may autocast their inputs to the layer's own dtype).
  latents = tf.reshape(
      latents, [-1, num_samples, self._ensemble_size, self._num_classes]
  )

  if self._num_classes == 2:
    return tf.math.sigmoid(latents / self._temperature)
  else:
    return tf.nn.softmax(latents / self._temperature)
def _compute_loc_param(self, inputs):
  """Returns the location parameter of the "logits distribution".

  Args:
    inputs: Tensor. The input to the heteroscedastic output layer.

  Returns:
    Output of the location Dense head, which has
    num_classes * ensemble_size units.
  """
  loc_logits = self._loc_layer(inputs)
  return loc_logits
def _compute_scale_param(self, inputs):
  """Returns the scale parameters of the "logits distribution".

  Args:
    inputs: Tensor. The input to the heteroscedastic output layer.

  Returns:
    Tuple `(factor_part, diag_scale)`. `diag_scale` is the softplus diagonal
    head output (num_classes * ensemble_size units) plus
    MIN_SCALE_MONTE_CARLO. `factor_part` is `inputs` itself when
    parameter_efficient is True, otherwise the flattened factor loadings
    (num_classes * ensemble_size * num_factors units).
  """
  # The factor head is evaluated before the diagonal head, as before.
  if self._parameter_efficient:
    factor_part = inputs
  else:
    factor_part = self._scale_layer(inputs)
  diag_scale = self._diag_layer(inputs) + MIN_SCALE_MONTE_CARLO
  return factor_part, diag_scale
def _compute_diagonal_noise_samples(self, diag_scale, num_samples, seed):
"""Compute samples of the diagonal elements logit noise.
Args:
diag_scale: `Tensor` of shape [batch_size, num_classes * ensemble_size].
Diagonal elements of scale parameters of the distribution to be sampled.
num_samples: Integer. Number of Monte-Carlo samples to take.
seed: Python integer for seeding the random number generator.
Returns:
`Tensor`. Logit noise samples of shape: [batch_size, num_samples,
1 if num_classes == 2 else num_classes].
"""
if self._share_samples_across_batch:
num_noise_samples = 1
else:
num_noise_samples = tf.shape(diag_scale)[0]
dist = tfp.distributions.Normal(
loc=tf.zeros(
[num_noise_samples, self._num_classes * self._ensemble_size],