diff --git a/simulai/models/_pytorch_models/_transformer.py b/simulai/models/_pytorch_models/_transformer.py index b43106c..32faf8b 100644 --- a/simulai/models/_pytorch_models/_transformer.py +++ b/simulai/models/_pytorch_models/_transformer.py @@ -8,13 +8,14 @@ class BaseTemplate(NetworkTemplate): - def __init__(self): + def __init__(self, device: str = "cpu"): """Template used for sharing fundamental methods with the children transformer-like encoders and decoders. """ super(BaseTemplate, self).__init__() + self.device = device def _activation_getter( self, activation: Union[str, torch.nn.Module] @@ -33,7 +34,9 @@ def _activation_getter( if isinstance(activation, torch.nn.Module): return encoder_activation elif isinstance(activation, str): - return self._get_operation(operation=activation, is_activation=True) + act = self._get_operation(operation=activation, is_activation=True) + act.setup(device=self.device) + return act else: raise Exception(f"The activation {activation} is not supported.") @@ -45,6 +48,7 @@ def __init__( activation: Union[str, torch.nn.Module] = "relu", mlp_layer: torch.nn.Module = None, embed_dim: Union[int, Tuple] = None, + device: str = "cpu", ) -> None: """Generic transformer encoder. @@ -56,7 +60,7 @@ def __init__( """ - super(BasicEncoder, self).__init__() + super(BasicEncoder, self).__init__(device=device) self.num_heads = num_heads @@ -107,6 +111,7 @@ def __init__( activation: Union[str, torch.nn.Module] = "relu", mlp_layer: torch.nn.Module = None, embed_dim: Union[int, Tuple] = None, + device: str = "cpu", ): """Generic transformer decoder. 
@@ -253,6 +258,7 @@ def __init__( activation=self.encoder_activation, mlp_layer=self.encoder_mlp_layers_list[e], embed_dim=self.embed_dim_encoder, + device=self.device, ) for e in range(self.number_of_encoders) ] @@ -266,6 +272,7 @@ def __init__( activation=self.decoder_activation, mlp_layer=self.decoder_mlp_layers_list[d], embed_dim=self.embed_dim_decoder, + device=self.device, ) for d in range(self.number_of_decoders) ] diff --git a/simulai/templates/_pytorch_network.py b/simulai/templates/_pytorch_network.py index b5d93e8..03c4e17 100644 --- a/simulai/templates/_pytorch_network.py +++ b/simulai/templates/_pytorch_network.py @@ -169,10 +169,25 @@ def _setup_activations( # It instantiates an operation x^l = \sigma(y^l), in which y^l # is the output of the previous linear operation. if isinstance(activation, str): + # Instantiate a sample activation to check whether it is trainable. activation_op = self._get_operation(operation=activation) + if isinstance(activation_op, simulact.TrainableActivation): + + activations_list = [self._get_operation(operation=activation, + is_activation=True) + for i in range(n_layers - 1)] + + for aa, act in enumerate(activations_list): + act.setup(device=self.device_type) + activations_list[aa] = act + + else: + activations_list = [self._get_operation(operation=activation) + for i in range(n_layers - 1)] + return ( - [self._get_operation(operation=activation) for i in range(n_layers - 1)] + activations_list + [self._get_operation(operation=self.default_last_activation)], (n_layers - 1) * [activation] + [self.default_last_activation], )