"""
# Modification of hypnettorch file
# https://hypnettorch.readthedocs.io/en/latest/_modules/hypnettorch/mnets/mlp.html#MLP# licensed under the Apache License, Version 2.0
# HyperMask with FeCAM needed some modifications due to feature extractions.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from hypnettorch.mnets.mnet_interface import MainNetInterface
from hypnettorch.utils.batchnorm_layer import BatchNormLayer
from hypnettorch.utils.context_mod_layer import ContextModLayer
from hypnettorch.utils.torch_utils import init_params
from copy import deepcopy
class MLPFeCAM(nn.Module, MainNetInterface):
"""Implementation of a Multi-Layer Perceptron (MLP).
    This is a simple fully-connected network that receives an input vector
    :math:`\mathbf{x}` and outputs a vector :math:`\mathbf{y}` of real values.
    The output mapping does not include a non-linearity by default, since we
    want to map to the whole real line (but see argument ``out_fn``).
Args:
n_in (int): Number of inputs.
n_out (int): Number of outputs.
hidden_layers (list or tuple): A list of integers, each number denoting
the size of a hidden layer.
activation_fn: The nonlinearity used in hidden layers. If ``None``, no
nonlinearity will be applied.
use_bias (bool): Whether layers may have bias terms.
no_weights (bool): If set to ``True``, no trainable parameters will be
constructed, i.e., weights are assumed to be produced ad-hoc
by a hypernetwork and passed to the :meth:`forward` method.
        init_weights (optional): This option is provided for convenience.
            It expects a list of parameter values that are used to
            initialize the network weights. As such, it provides a
            convenient way of initializing a network with a weight draw
            produced by the hypernetwork (see the example at the end of
            this docstring).
Note, internal weights (see
:attr:`mnets.mnet_interface.MainNetInterface.weights`) will be
affected by this argument only.
dropout_rate: If ``-1``, no dropout will be applied. Otherwise a number
between 0 and 1 is expected, denoting the dropout rate of hidden
layers.
use_spectral_norm: Use spectral normalization for training.
use_batch_norm (bool): Whether batch normalization should be used. Will
be applied before the activation function in all hidden layers.
bn_track_stats (bool): If batch normalization is used, then this option
determines whether running statistics are tracked in these
layers or not (see argument ``track_running_stats`` of class
:class:`utils.batchnorm_layer.BatchNormLayer`).
If ``False``, then batch statistics are utilized even during
evaluation. If ``True``, then running stats are tracked. When
using this network in a continual learning scenario with
            different tasks, the running statistics are expected to be
maintained externally. The argument ``stats_id`` of the method
:meth:`utils.batchnorm_layer.BatchNormLayer.forward` can be
provided using the argument ``condition`` of method :meth:`forward`.
Example:
To maintain the running stats, one can simply iterate over
all batch norm layers and checkpoint the current running
                stats (e.g., after learning a task in a continual
                learning scenario).
.. code:: python
for bn_layer in net.batchnorm_layers:
bn_layer.checkpoint_stats()
distill_bn_stats (bool): If ``True``, then the shapes of the batchnorm
statistics will be added to the attribute
:attr:`mnets.mnet_interface.MainNetInterface.\
hyper_shapes_distilled` and the current statistics will be returned by the
method :meth:`distillation_targets`.
Note, this attribute may only be ``True`` if ``bn_track_stats``
is ``True``.
use_context_mod (bool): Add context-dependent modulation layers
:class:`utils.context_mod_layer.ContextModLayer` after the linear
computation of each layer.
        context_mod_inputs (bool): Whether context-dependent modulation should
            also be applied to network inputs directly. I.e., assume
:math:`\mathbf{x}` is the input to the network. Then the first
network operation would be to modify the input via
:math:`\mathbf{x} \cdot \mathbf{g} + \mathbf{s}` using context-
dependent gain and shift parameters.
Note:
Argument applies only if ``use_context_mod`` is ``True``.
no_last_layer_context_mod (bool): If ``True``, context-dependent
modulation will not be applied to the output layer.
Note:
Argument applies only if ``use_context_mod`` is ``True``.
context_mod_no_weights (bool): The weights of the context-mod layers
(:class:`utils.context_mod_layer.ContextModLayer`) are treated
independently of the option ``no_weights``.
This argument can be used to decide whether the context-mod
parameters (gains and shifts) are maintained internally or
externally.
Note:
Check out argument ``weights`` of the :meth:`forward` method
on how to correctly pass weights to the network that are
externally maintained.
context_mod_post_activation (bool): Apply context-mod layers after the
            activation function (``activation_fn``) in hidden layers rather than
before, which is the default behavior.
Note:
This option only applies if ``use_context_mod`` is ``True``.
Note:
This option does not affect argument ``context_mod_inputs``.
Note:
This option does not affect argument
                ``no_last_layer_context_mod``. Hence, if an output nonlinearity
is applied through argument ``out_fn``, then context-modulation
would be applied before this non-linearity.
context_mod_gain_offset (bool): Activates option ``apply_gain_offset``
of class :class:`utils.context_mod_layer.ContextModLayer` for all
context-mod layers that will be instantiated.
context_mod_gain_softplus (bool): Activates option
``apply_gain_softplus`` of class
:class:`utils.context_mod_layer.ContextModLayer` for all
context-mod layers that will be instantiated.
out_fn (optional): If provided, this function will be applied to the
output neurons of the network.
Warning:
This changes the interpretation of the output of the
:meth:`forward` method.
verbose (bool): Whether to print information (e.g., the number of
weights) during the construction of the network.
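
    Example:
        A minimal construction sketch (the sizes below are arbitrary). With
        ``no_weights=True`` the network holds no internal linear weights, and
        :attr:`mnets.mnet_interface.MainNetInterface.hyper_shapes_learned`
        lists the shapes a hypernetwork would have to produce:

        .. code:: python

            net = MLPFeCAM(n_in=784, n_out=10, hidden_layers=(100, 100),
                           no_weights=True)
            print(net.hyper_shapes_learned)
            # [[100, 784], [100], [100, 100], [100], [10, 100], [10]]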
"""
def __init__(
self,
n_in=1,
n_out=1,
hidden_layers=(10, 10),
activation_fn=torch.nn.ReLU(),
use_bias=True,
no_weights=False,
init_weights=None,
dropout_rate=-1,
use_spectral_norm=False,
use_batch_norm=False,
bn_track_stats=True,
distill_bn_stats=False,
use_context_mod=False,
context_mod_inputs=False,
no_last_layer_context_mod=False,
context_mod_no_weights=False,
context_mod_post_activation=False,
context_mod_gain_offset=False,
context_mod_gain_softplus=False,
out_fn=None,
verbose=True,
):
# FIXME find a way using super to handle multiple inheritance.
nn.Module.__init__(self)
MainNetInterface.__init__(self)
# FIXME Spectral norm is incorrectly implemented. Function
# `nn.utils.spectral_norm` needs to be called in the constructor, such
        # that spectral norm is wrapped around a module.
if use_spectral_norm:
raise NotImplementedError(
"Spectral normalization not yet "
+ "implemented for this network."
)
if use_batch_norm and use_context_mod:
# FIXME Does it make sense to have both enabled?
# I.e., should we produce a warning or error?
pass
        # Tuples are not mutable.
hidden_layers = list(hidden_layers)
self._a_fun = activation_fn
assert init_weights is None or (
not no_weights or not context_mod_no_weights
)
self._no_weights = no_weights
self._dropout_rate = dropout_rate
# self._use_spectral_norm = use_spectral_norm
self._use_batch_norm = use_batch_norm
self._bn_track_stats = bn_track_stats
self._distill_bn_stats = distill_bn_stats and use_batch_norm
self._use_context_mod = use_context_mod
self._context_mod_inputs = context_mod_inputs
self._no_last_layer_context_mod = no_last_layer_context_mod
self._context_mod_no_weights = context_mod_no_weights
self._context_mod_post_activation = context_mod_post_activation
self._context_mod_gain_offset = context_mod_gain_offset
self._context_mod_gain_softplus = context_mod_gain_softplus
self._out_fn = out_fn
self._has_bias = use_bias
self._has_fc_out = True
# We need to make sure that the last 2 entries of `weights` correspond
# to the weight matrix and bias vector of the last layer.
self._mask_fc_out = True
self._has_linear_out = True if out_fn is None else False
if use_spectral_norm and no_weights:
raise ValueError(
"Cannot use spectral norm in a network without " + "parameters."
)
# FIXME make sure that this implementation is correct in all situations
# (e.g., what to do if weights are passed to the forward method?).
if use_spectral_norm:
self._spec_norm = nn.utils.spectral_norm
else:
self._spec_norm = lambda x: x # identity
self._param_shapes = []
self._param_shapes_meta = []
self._weights = (
None
if no_weights and context_mod_no_weights
else nn.ParameterList()
)
self._hyper_shapes_learned = (
None if not no_weights and not context_mod_no_weights else []
)
self._hyper_shapes_learned_ref = (
None if self._hyper_shapes_learned is None else []
)
if dropout_rate != -1:
assert dropout_rate >= 0.0 and dropout_rate <= 1.0
self._dropout = nn.Dropout(p=dropout_rate)
### Define and initialize context mod weights.
self._context_mod_layers = nn.ModuleList() if use_context_mod else None
self._context_mod_shapes = [] if use_context_mod else None
if use_context_mod:
cm_ind = 0
cm_sizes = []
if context_mod_inputs:
cm_sizes.append(n_in)
cm_sizes.extend(hidden_layers)
if not no_last_layer_context_mod:
cm_sizes.append(n_out)
for i, n in enumerate(cm_sizes):
cmod_layer = ContextModLayer(
n,
no_weights=context_mod_no_weights,
apply_gain_offset=context_mod_gain_offset,
apply_gain_softplus=context_mod_gain_softplus,
)
self._context_mod_layers.append(cmod_layer)
self.param_shapes.extend(cmod_layer.param_shapes)
assert len(cmod_layer.param_shapes) == 2
self._param_shapes_meta.extend(
[
{
"name": "cm_scale",
"index": -1
if context_mod_no_weights
else len(self._weights),
"layer": -1,
}, # 'layer' is set later.
{
"name": "cm_shift",
"index": -1
if context_mod_no_weights
else len(self._weights) + 1,
"layer": -1,
}, # 'layer' is set later.
]
)
self._context_mod_shapes.extend(cmod_layer.param_shapes)
if context_mod_no_weights:
self._hyper_shapes_learned.extend(cmod_layer.param_shapes)
else:
self._weights.extend(cmod_layer.weights)
# FIXME ugly code. Move initialization somewhere else.
if not context_mod_no_weights and init_weights is not None:
assert len(cmod_layer.weights) == 2
for ii in range(2):
assert np.all(
np.equal(
list(init_weights[cm_ind].shape),
                                list(cmod_layer.weights[ii].shape),
)
)
cmod_layer.weights[ii].data = init_weights[cm_ind]
cm_ind += 1
if init_weights is not None:
init_weights = init_weights[cm_ind:]
if context_mod_no_weights:
self._hyper_shapes_learned_ref = list(
range(len(self._param_shapes))
)
### Define and initialize batch norm weights.
self._batchnorm_layers = nn.ModuleList() if use_batch_norm else None
if use_batch_norm:
if distill_bn_stats:
self._hyper_shapes_distilled = []
bn_ind = 0
for i, n in enumerate(hidden_layers):
bn_layer = BatchNormLayer(
n, affine=not no_weights, track_running_stats=bn_track_stats
)
self._batchnorm_layers.append(bn_layer)
self._param_shapes.extend(bn_layer.param_shapes)
assert len(bn_layer.param_shapes) == 2
self._param_shapes_meta.extend(
[
{
"name": "bn_scale",
"index": -1 if no_weights else len(self._weights),
"layer": -1,
}, # 'layer' is set later.
{
"name": "bn_shift",
"index": -1
if no_weights
else len(self._weights) + 1,
"layer": -1,
}, # 'layer' is set later.
]
)
if no_weights:
self._hyper_shapes_learned.extend(bn_layer.param_shapes)
else:
self._weights.extend(bn_layer.weights)
if distill_bn_stats:
self._hyper_shapes_distilled.extend(
[list(p.shape) for p in bn_layer.get_stats(0)]
)
# FIXME ugly code. Move initialization somewhere else.
if not no_weights and init_weights is not None:
assert len(bn_layer.weights) == 2
for ii in range(2):
assert np.all(
np.equal(
list(init_weights[bn_ind].shape),
list(bn_layer.weights[ii].shape),
)
)
bn_layer.weights[ii].data = init_weights[bn_ind]
bn_ind += 1
if init_weights is not None:
init_weights = init_weights[bn_ind:]
### Compute shapes of linear layers.
linear_shapes = MLPFeCAM.weight_shapes(
n_in=n_in,
n_out=n_out,
hidden_layers=hidden_layers,
use_bias=use_bias,
)
self._param_shapes.extend(linear_shapes)
for i, s in enumerate(linear_shapes):
self._param_shapes_meta.append(
{
"name": "weight" if len(s) != 1 else "bias",
"index": -1 if no_weights else len(self._weights) + i,
"layer": -1, # 'layer' is set later.
}
)
num_weights = MainNetInterface.shapes_to_num_weights(self._param_shapes)
### Set missing meta information of param_shapes.
offset = 1 if use_context_mod and context_mod_inputs else 0
shift = 1
if use_batch_norm:
shift += 1
if use_context_mod:
shift += 1
cm_offset = 2 if context_mod_post_activation else 1
bn_offset = 1 if context_mod_post_activation else 2
cm_ind = 0
bn_ind = 0
layer_ind = 0
for i, dd in enumerate(self._param_shapes_meta):
if dd["name"].startswith("cm"):
if offset == 1 and i in [0, 1]:
dd["layer"] = 0
else:
if cm_ind < len(hidden_layers):
dd["layer"] = offset + cm_ind * shift + cm_offset
else:
assert (
cm_ind == len(hidden_layers)
and not no_last_layer_context_mod
)
# No batchnorm in output layer.
dd["layer"] = offset + cm_ind * shift + 1
if dd["name"] == "cm_shift":
cm_ind += 1
elif dd["name"].startswith("bn"):
dd["layer"] = offset + bn_ind * shift + bn_offset
if dd["name"] == "bn_shift":
bn_ind += 1
else:
dd["layer"] = offset + layer_ind * shift
if not use_bias or dd["name"] == "bias":
layer_ind += 1
        ### User information
if verbose:
if use_context_mod:
cm_num_weights = 0
for cm_layer in self._context_mod_layers:
cm_num_weights += MainNetInterface.shapes_to_num_weights(
cm_layer.param_shapes
)
print(
"Creating an MLP with %d weights" % num_weights
+ (
" (including %d weights associated with-" % cm_num_weights
+ "context modulation)"
if use_context_mod
else ""
)
+ "."
+ (" The network uses dropout." if dropout_rate != -1 else "")
+ (" The network uses batchnorm." if use_batch_norm else "")
)
self._layer_weight_tensors = nn.ParameterList()
self._layer_bias_vectors = nn.ParameterList()
if no_weights:
self._hyper_shapes_learned.extend(linear_shapes)
if use_context_mod:
if context_mod_no_weights:
self._hyper_shapes_learned_ref = list(
range(len(self._param_shapes))
)
else:
ncm = len(self._context_mod_shapes)
self._hyper_shapes_learned_ref = list(
range(ncm, len(self._param_shapes))
)
self._is_properly_setup()
return
### Define and initialize linear weights.
for i, dims in enumerate(linear_shapes):
self._weights.append(
nn.Parameter(torch.Tensor(*dims), requires_grad=True)
)
if len(dims) == 1:
self._layer_bias_vectors.append(self._weights[-1])
else:
self._layer_weight_tensors.append(self._weights[-1])
if init_weights is not None:
assert len(init_weights) == len(linear_shapes)
for i in range(len(init_weights)):
assert np.all(
np.equal(list(init_weights[i].shape), linear_shapes[i])
)
if use_bias:
if i % 2 == 0:
self._layer_weight_tensors[i // 2].data = init_weights[
i
]
else:
self._layer_bias_vectors[i // 2].data = init_weights[i]
else:
self._layer_weight_tensors[i].data = init_weights[i]
else:
for i in range(len(self._layer_weight_tensors)):
if use_bias:
init_params(
self._layer_weight_tensors[i],
self._layer_bias_vectors[i],
)
else:
init_params(self._layer_weight_tensors[i])
if self._num_context_mod_shapes() == 0:
# Note, that might be the case if no hidden layers exist and no
# input or output modulation is used.
self._use_context_mod = False
self._is_properly_setup()
def forward(self, x, weights=None, distilled_params=None, condition=None):
"""Compute the output :math:`y` of this network given the input
:math:`x`.
Args:
(....): See docstring of method
:meth:`mnets.mnet_interface.MainNetInterface.forward`. We
provide some more specific information below.
weights (list or dict): If a list of parameter tensors is given and
context modulation is used (see argument ``use_context_mod`` in
constructor), then these parameters are interpreted as context-
modulation parameters if the length of ``weights`` equals
:code:`2*len(net.context_mod_layers)`. Otherwise, the length is
expected to be equal to the length of the attribute
:attr:`mnets.mnet_interface.MainNetInterface.param_shapes`.
Alternatively, a dictionary can be passed with the possible
keywords ``internal_weights`` and ``mod_weights``. Each keyword
is expected to map onto a list of tensors.
The keyword ``internal_weights`` refers to all weights of this
network except for the weights of the context-modulation layers.
The keyword ``mod_weights``, on the other hand, refers
specifically to the weights of the context-modulation layers.
It is not necessary to specify both keywords.
distilled_params: Will be passed as ``running_mean`` and
``running_var`` arguments of method
:meth:`utils.batchnorm_layer.BatchNormLayer.forward` if
batch normalization is used.
condition (int or dict, optional): If ``int`` is provided, then this
argument will be passed as argument ``stats_id`` to the method
:meth:`utils.batchnorm_layer.BatchNormLayer.forward` if
batch normalization is used.
If a ``dict`` is provided instead, the following keywords are
allowed:
- ``bn_stats_id``: Will be handled as ``stats_id`` of the
batchnorm layers as described above.
- ``cmod_ckpt_id``: Will be passed as argument ``ckpt_id``
to the method
:meth:`utils.context_mod_layer.ContextModLayer.forward`.
Returns:
(tuple): Tuple containing:
- **y**: The output of the network.
- **h_y** (optional): If ``out_fn`` was specified in the
constructor, then this value will be returned. It is the last
hidden activation (before the ``out_fn`` has been applied).
            - **features** (optional): The activations of the last hidden
              layer, i.e., the features extracted by the linear layers before
              the output mapping. FeCAM needs these features for class
              selection (see the example below).
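
        Example:
            A minimal sketch, assuming ``net`` is an instance of this class
            built with ``no_weights=True`` (and without context modulation),
            and ``hnet_out`` is a list of tensors, e.g., produced by a
            hypernetwork, whose shapes match
            :attr:`mnets.mnet_interface.MainNetInterface.param_shapes`:

            .. code:: python

                # `n_in` denotes the input size passed to the constructor.
                x = torch.randn(32, n_in)
                with torch.no_grad():
                    y, features = net.forward(x, weights=hnet_out)
                # Equivalent dictionary form:
                # y, features = net.forward(
                #     x, weights={"internal_weights": hnet_out})
                # `features` holds the last-hidden-layer activations used by
                # FeCAM.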
"""
if (
(not self._use_context_mod and self._no_weights)
or (self._no_weights or self._context_mod_no_weights)
) and weights is None:
raise Exception(
"Network was generated without weights. "
+ 'Hence, "weights" option may not be None.'
)
############################################
### Extract which weights should be used ###
############################################
# I.e., are we using internally maintained weights or externally given
# ones or are we even mixing between these groups.
n_cm = self._num_context_mod_shapes()
if weights is None:
weights = self.weights
if self._use_context_mod:
cm_weights = weights[:n_cm]
int_weights = weights[n_cm:]
else:
int_weights = weights
else:
int_weights = None
cm_weights = None
if isinstance(weights, dict):
assert (
"internal_weights" in weights.keys()
or "mod_weights" in weights.keys()
)
if "internal_weights" in weights.keys():
int_weights = weights["internal_weights"]
if "mod_weights" in weights.keys():
cm_weights = weights["mod_weights"]
else:
if self._use_context_mod and len(weights) == n_cm:
cm_weights = weights
else:
assert len(weights) == len(self.param_shapes)
if self._use_context_mod:
cm_weights = weights[:n_cm]
int_weights = weights[n_cm:]
else:
int_weights = weights
if self._use_context_mod and cm_weights is None:
if self._context_mod_no_weights:
raise Exception(
"Network was generated without weights "
+ "for context-mod layers. Hence, they must be passed "
+ 'via the "weights" option.'
)
cm_weights = self.weights[:n_cm]
if int_weights is None:
if self._no_weights:
raise Exception(
"Network was generated without internal "
+ "weights. Hence, they must be passed via the "
+ '"weights" option.'
)
if self._context_mod_no_weights:
int_weights = self.weights
else:
int_weights = self.weights[n_cm:]
# Note, context-mod weights might have different shapes, as they
# may be parametrized on a per-sample basis.
if self._use_context_mod:
assert len(cm_weights) == len(self._context_mod_shapes)
int_shapes = self.param_shapes[n_cm:]
assert len(int_weights) == len(int_shapes)
for i, s in enumerate(int_shapes):
assert np.all(np.equal(s, list(int_weights[i].shape)))
cm_ind = 0
bn_ind = 0
if self._use_batch_norm:
n_bn = 2 * len(self.batchnorm_layers)
bn_weights = int_weights[:n_bn]
layer_weights = int_weights[n_bn:]
else:
layer_weights = int_weights
w_weights = []
b_weights = []
for i, p in enumerate(layer_weights):
if self.has_bias and i % 2 == 1:
b_weights.append(p)
else:
w_weights.append(p)
########################
### Parse condition ###
#######################
bn_cond = None
cmod_cond = None
if condition is not None:
if isinstance(condition, dict):
assert (
"bn_stats_id" in condition.keys()
or "cmod_ckpt_id" in condition.keys()
)
if "bn_stats_id" in condition.keys():
bn_cond = condition["bn_stats_id"]
if "cmod_ckpt_id" in condition.keys():
cmod_cond = condition["cmod_ckpt_id"]
# FIXME We always require context-mod weight above, but
# we can't pass both (a condition and weights) to the
# context-mod layers.
                    # An inelegant solution would be to just set all
# context-mod weights to None.
raise NotImplementedError("CM-conditions not implemented!")
else:
bn_cond = condition
######################################
### Select batchnorm running stats ###
######################################
if self._use_batch_norm:
            num_bn_layers = len(self._batchnorm_layers)
            running_means = [None] * num_bn_layers
            running_vars = [None] * num_bn_layers
if distilled_params is not None:
if not self._distill_bn_stats:
raise ValueError(
'Argument "distilled_params" can only be '
+ "provided if the return value of "
+ 'method "distillation_targets()" is not None.'
)
shapes = self.hyper_shapes_distilled
assert len(distilled_params) == len(shapes)
for i, s in enumerate(shapes):
assert np.all(np.equal(s, list(distilled_params[i].shape)))
# Extract batchnorm stats from distilled_params
for i in range(0, len(distilled_params), 2):
running_means[i // 2] = distilled_params[i]
running_vars[i // 2] = distilled_params[i + 1]
elif self._use_batch_norm and self._bn_track_stats and bn_cond is None:
for i, bn_layer in enumerate(self._batchnorm_layers):
running_means[i], running_vars[i] = bn_layer.get_stats()
###########################
### Forward Computation ###
###########################
hidden = x
# Context-dependent modulation of inputs directly.
if self._use_context_mod and self._context_mod_inputs:
hidden = self._context_mod_layers[cm_ind].forward(
hidden,
weights=cm_weights[2 * cm_ind : 2 * cm_ind + 2],
ckpt_id=cmod_cond,
)
cm_ind += 1
for l in range(len(w_weights)):
W = w_weights[l]
if self.has_bias:
b = b_weights[l]
else:
b = None
            # Linear layer. The activations entering the final linear layer
            # are kept as the extracted features for FeCAM.
            if l == (len(w_weights) - 1):
                features = deepcopy(hidden)
            hidden = self._spec_norm(F.linear(hidden, W, bias=b))
# Only for hidden layers.
if l < len(w_weights) - 1:
# Context-dependent modulation (pre-activation).
if (
self._use_context_mod
and not self._context_mod_post_activation
):
hidden = self._context_mod_layers[cm_ind].forward(
hidden,
weights=cm_weights[2 * cm_ind : 2 * cm_ind + 2],
ckpt_id=cmod_cond,
)
cm_ind += 1
# Batch norm
if self._use_batch_norm:
hidden = self._batchnorm_layers[bn_ind].forward(
hidden,
running_mean=running_means[bn_ind],
running_var=running_vars[bn_ind],
weight=bn_weights[2 * bn_ind],
bias=bn_weights[2 * bn_ind + 1],
stats_id=bn_cond,
)
bn_ind += 1
# Dropout
if self._dropout_rate != -1:
hidden = self._dropout(hidden)
# Non-linearity
if self._a_fun is not None:
hidden = self._a_fun(hidden)
# Context-dependent modulation (post-activation).
if self._use_context_mod and self._context_mod_post_activation:
hidden = self._context_mod_layers[cm_ind].forward(
hidden,
weights=cm_weights[2 * cm_ind : 2 * cm_ind + 2],
ckpt_id=cmod_cond,
)
cm_ind += 1
# Context-dependent modulation in output layer.
if self._use_context_mod and not self._no_last_layer_context_mod:
hidden = self._context_mod_layers[cm_ind].forward(
hidden,
weights=cm_weights[2 * cm_ind : 2 * cm_ind + 2],
ckpt_id=cmod_cond,
)
if self._out_fn is not None:
return self._out_fn(hidden), hidden
return [hidden, features]
def distillation_targets(self):
"""Targets to be distilled after training.
See docstring of abstract super method
:meth:`mnets.mnet_interface.MainNetInterface.distillation_targets`.
This method will return the current batch statistics of all batch
normalization layers if ``distill_bn_stats`` and ``use_batch_norm``
        were set to ``True`` in the constructor.
Returns:
The target tensors corresponding to the shapes specified in
attribute :attr:`hyper_shapes_distilled`.
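
        Example:
            A sketch, assuming ``net`` was constructed with
            ``use_batch_norm=True`` and ``distill_bn_stats=True``; each
            batchnorm layer contributes its running mean and variance:

            .. code:: python

                targets = net.distillation_targets()
                assert len(targets) == 2 * len(net.batchnorm_layers)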
"""
if self.hyper_shapes_distilled is None:
return None
ret = []
for bn_layer in self._batchnorm_layers:
ret.extend(bn_layer.get_stats())
return ret
@staticmethod
def weight_shapes(n_in=1, n_out=1, hidden_layers=[10, 10], use_bias=True):
"""Compute the tensor shapes of all parameters in a fully-connected
network.
Args:
n_in: Number of inputs.
n_out: Number of output units.
hidden_layers: A list of ints, each number denoting the size of a
hidden layer.
use_bias: Whether the FC layers should have biases.
Returns:
A list of list of integers, denoting the shapes of the individual
parameter tensors.
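
        Example:
            For instance, for a network with 3 inputs, a single hidden layer
            of size 5, and 2 outputs (with biases):

            .. code:: python

                MLPFeCAM.weight_shapes(n_in=3, n_out=2, hidden_layers=[5])
                # returns [[5, 3], [5], [2, 5], [2]]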
"""
shapes = []
prev_dim = n_in
layer_out_sizes = hidden_layers + [n_out]
for i, size in enumerate(layer_out_sizes):
shapes.append([size, prev_dim])
if use_bias:
shapes.append([size])
prev_dim = size
return shapes
if __name__ == "__main__":
pass
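    # --- Minimal usage sketch (illustrative only; sizes are arbitrary). ---
    # Build a network with internal weights and run one forward pass. The
    # pass is wrapped in `torch.no_grad()` because the FeCAM feature copy
    # relies on `deepcopy`, which cannot copy non-leaf tensors that require
    # gradients.
    torch.manual_seed(0)
    net = MLPFeCAM(n_in=4, n_out=3, hidden_layers=(8,))
    x = torch.randn(2, 4)
    with torch.no_grad():
        y, features = net(x)
    print("output shape:", tuple(y.shape), "feature shape:", tuple(features.shape))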