diff --git a/train/Train.lua b/train/Train.lua index 8cf45f9..987f2cd 100644 --- a/train/Train.lua +++ b/train/Train.lua @@ -14,6 +14,8 @@ function Train:__init(loss_wrapper, batcher, optimization_config, general_config self.model = optimization_config.modules_to_update + self.gradient_clip = optimization_config.gradient_clip + self.clamp_gradient = self.gradient_clip > 0 self.parameters, self.grad_parameters = self.model:getParameters() self.data_for_callbacks = { parameters = self.parameters,