From 1843604cc10c8251831f232eb9e61762110f8ed0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?=
Date: Mon, 12 Feb 2024 12:51:50 -0500
Subject: [PATCH] The activation parameters were not being properly placed on
 the device.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: João Lucas de Sousa Almeida
---
 simulai/models/_pytorch_models/_transformer.py | 2 +-
 simulai/regression/_pytorch/_dense.py          | 2 +-
 simulai/templates/_pytorch_network.py          | 7 ++++---
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/simulai/models/_pytorch_models/_transformer.py b/simulai/models/_pytorch_models/_transformer.py
index 32faf8b..4b1f641 100644
--- a/simulai/models/_pytorch_models/_transformer.py
+++ b/simulai/models/_pytorch_models/_transformer.py
@@ -3,7 +3,7 @@ import torch
 
 from typing import Union, Tuple
 
-from simulai.templates import NetworkTemplate, as_tensor
+from simulai.templates import NetworkTemplate, as_tensor, guarantee_device
 
 from simulai.regression import DenseNetwork, Linear
 
diff --git a/simulai/regression/_pytorch/_dense.py b/simulai/regression/_pytorch/_dense.py
index 7f3fccf..3445862 100644
--- a/simulai/regression/_pytorch/_dense.py
+++ b/simulai/regression/_pytorch/_dense.py
@@ -200,7 +200,7 @@ def __init__(
 
         """
 
-        super(DenseNetwork, self).__init__()
+        super(DenseNetwork, self).__init__(**kwargs)
 
         assert layers_units, "Please, set a list of units for each layer"
 
diff --git a/simulai/templates/_pytorch_network.py b/simulai/templates/_pytorch_network.py
index 6c47f00..a49ccfa 100644
--- a/simulai/templates/_pytorch_network.py
+++ b/simulai/templates/_pytorch_network.py
@@ -55,8 +55,9 @@ def __init__(self, name: str = None, devices: str = None) -> None:
         self.shapes_dict = None
 
         self.device_type = devices
+        self.device = self._set_device(devices=devices)
 
-        if self.device_type:
+        if self.device_type != "cpu":
             self.to_wrap = self._to_explicit_device
         else:
             self.to_wrap = self._to_bypass
@@ -148,7 +149,7 @@ def _get_operation(
             if torch.nn.Module in res_.__mro__:
                 res = res_
                 print(f"Module {operation} found in {engine}.")
-                return res()
+                return res(**kwargs)
             else:
                 print(f"Module {operation} not found in {engine}.")
         else:
@@ -175,7 +176,7 @@ def _setup_activations(
         if isinstance(activation_op, simulact.TrainableActivation):
 
             activations_list = [self._get_operation(operation=activation,
-                                                    is_activation=True, device=self.device_type)
+                                                    is_activation=True, device=self.device)
                                 for i in range(n_layers - 1)]
 
         else:
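
Note for reviewers (illustration only, not part of the patch): the hunks above forward the resolved self.device, rather than the raw devices string, into _get_operation, and pass **kwargs through to the activation constructor, so trainable activation parameters are created on the intended device. Below is a minimal sketch of the underlying issue in plain PyTorch, not the simulai API; TrainableSiren and its omega parameter are hypothetical stand-ins for a TrainableActivation-style module.

    import torch


    class TrainableSiren(torch.nn.Module):
        """Toy trainable activation with a learnable frequency parameter."""

        def __init__(self, device: str = "cpu") -> None:
            super().__init__()
            # The fix amounts to forwarding the resolved device here
            # (device=self.device) instead of the raw `devices` string.
            self.omega = torch.nn.Parameter(torch.ones(1, device=device))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.sin(self.omega * x)


    device = "cuda" if torch.cuda.is_available() else "cpu"
    act = TrainableSiren(device=device)
    x = torch.linspace(0, 1, steps=8, device=device)
    print(act(x).device)  # parameter and input now live on the same device

Without the explicit device argument, omega would stay on the CPU and the first forward pass on a GPU input would fail with a device-mismatch error.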