diff --git a/beacon8/layers/Module.py b/beacon8/layers/Module.py
index 38155c4..396d720 100644
--- a/beacon8/layers/Module.py
+++ b/beacon8/layers/Module.py
@@ -7,9 +7,12 @@ class Module:
     def __init__(self):
         self.training_mode = True
 
-        self.fn_forward = None
-        self.fn_accum_grads = None
-        self.fn_accum_stats = None
+        # The functions are stored in a dictionary whose keys correspond to the
+        # values that `self.training_mode` can take. That way, it would be
+        # trivial to extend to further modes, and the code avoids many branches.
+        self.fn_forward = {}
+        self.fn_accum_grads = {}
+        self.fn_accum_stats = {}
 
     #def __hash__(self):
     #    raise NotImplementedError("You *need* to reimplement hash, even if it's just python's default. See the documentation for more info.")
@@ -44,15 +47,18 @@ def symb_forward(self, symb_input):
         raise NotImplementedError
 
     def forward(self, data):
-        if self.fn_forward is None:
+        if self.training_mode not in self.fn_forward:
             symb_in = _T.TensorType(_th.config.floatX, (False,) * data.ndim)('X')
             symb_out = self.symb_forward(symb_in)
-            self.fn_forward = _th.function(inputs=[symb_in], outputs=symb_out)
+            self.fn_forward[self.training_mode] = _th.function(
+                inputs=[symb_in],
+                outputs=symb_out
+            )
 
-        return self.fn_forward(data)
+        return self.fn_forward[self.training_mode](data)
 
     def accumulate_gradients(self, data_in, data_tgt, loss):
-        if self.fn_accum_grads is None:
+        if self.training_mode not in self.fn_accum_grads:
             symb_in = _T.TensorType(_th.config.floatX, (False,) * data_in.ndim)('X')
             symb_tgt = _T.TensorType(_th.config.floatX, (False,) * data_tgt.ndim)('T')
             symb_out = self.symb_forward(symb_in)
@@ -62,18 +68,18 @@ def accumulate_gradients(self, data_in, data_tgt, loss):
             symb_grads = _th.grad(cost=symb_err, wrt=params)
 
             grads_updates = [(grad, grad + symb_grad) for grad, symb_grad in zip(grads, symb_grads)]
-            self.fn_accum_grads = _th.function(
+            self.fn_accum_grads[self.training_mode] = _th.function(
                 inputs=[symb_in, symb_tgt],
                 updates=grads_updates
             )
 
-        self.fn_accum_grads(data_in, data_tgt)
+        self.fn_accum_grads[self.training_mode](data_in, data_tgt)
 
     def get_stat_updates(self):
         return []
 
     def accumulate_statistics(self, data_in):
-        if self.fn_accum_stats is None:
+        if self.training_mode not in self.fn_accum_stats:
             symb_in = _T.TensorType(_th.config.floatX, (False,) * data_in.ndim)('X')
             self.symb_forward(symb_in)
 
@@ -83,9 +89,9 @@ def accumulate_statistics(self, data_in):
                 # compile and call a function. This prevents theano errors.
                 return
 
-            self.fn_accum_stats = _th.function(
+            self.fn_accum_stats[self.training_mode] = _th.function(
                 inputs=[symb_in],
                 updates=stat_updates
             )
 
-        self.fn_accum_stats(data_in)
+        self.fn_accum_stats[self.training_mode](data_in)
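
For reviewers, here is a minimal, self-contained sketch of the caching pattern the new comment describes. It is not part of the patch: the `ToyModule` class and its `_compile` helper are hypothetical stand-ins for the real Theano `_th.function(...)` compilation in `Module`. The point it illustrates is that compiled functions are memoized per value of `training_mode`, so flipping the mode costs at most one extra compilation instead of branching at every call site.

class ToyModule:
    def __init__(self):
        self.training_mode = True
        # One compiled function per mode, keyed by the current value of
        # `training_mode` -- mirrors the dict introduced in this patch.
        self.fn_forward = {}

    def _compile(self, mode):
        # Hypothetical stand-in for the Theano compilation step; the compiled
        # result is assumed to depend on the mode (e.g. dropout on/off).
        print("compiling forward pass for training_mode =", mode)
        return (lambda x: x * 2) if mode else (lambda x: x * 3)

    def forward(self, data):
        # Lazily compile once per distinct value of `training_mode`,
        # then reuse the cached function on every later call.
        if self.training_mode not in self.fn_forward:
            self.fn_forward[self.training_mode] = self._compile(self.training_mode)
        return self.fn_forward[self.training_mode](data)

m = ToyModule()
m.forward(1)             # compiles the training-mode function, returns 2
m.forward(1)             # reuses the cached function, returns 2
m.training_mode = False
m.forward(1)             # compiles once more for evaluation mode, returns 3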