diff --git a/simulai/templates/_pytorch_network.py b/simulai/templates/_pytorch_network.py index 03c4e17..6c47f00 100644 --- a/simulai/templates/_pytorch_network.py +++ b/simulai/templates/_pytorch_network.py @@ -126,7 +126,7 @@ def _get_from_guest(self, **kwargs) -> None: # Getting up activation if it exists def _get_operation( - self, operation: str = None, is_activation: bool = True + self, operation: str = None, is_activation: bool = True, **kwargs, ) -> callable: mod_items = dir(self.engine) mod_items_lower = [item.lower() for item in mod_items] @@ -136,7 +136,7 @@ def _get_operation( operation_class = getattr(self.engine, operation_name) if is_activation is True: - return operation_class() + return operation_class(**kwargs) else: return operation_class else: @@ -175,13 +175,9 @@ def _setup_activations( if isinstance(activation_op, simulact.TrainableActivation): activations_list = [self._get_operation(operation=activation, - is_activation=True) + is_activation=True, device=self.device_type) for i in range(n_layers - 1)] - for aa, act in enumerate(activations_list): - act.setup(device=self.device_type) - activations_list[aa] = act - else: activations_list = [self._get_operation(operation=activation) for i in range(n_layers - 1)]