Skip to content

Commit

Permalink
Merge pull request #180 from IBM/fix/wavelet_activation
Browse files Browse the repository at this point in the history
Fix/wavelet activation
  • Loading branch information
Joao-L-S-Almeida authored Feb 6, 2024
2 parents cf0dc64 + 94bede5 commit 8b67528
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
9 changes: 9 additions & 0 deletions simulai/models/_pytorch_models/_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def __init__(
decoder_mlp_layer_config: dict = None,
number_of_encoders: int = 1,
number_of_decoders: int = 1,
devices: Union[str, list] = "cpu",
) -> None:
r"""A classical encoder-decoder transformer:
Expand Down Expand Up @@ -229,6 +230,9 @@ def __init__(
self.encoder_mlp_layers_list = list()
self.decoder_mlp_layers_list = list()

#Determining the kind of device in which the modelwill be executed
self.device = self._set_device(devices=devices)

# Creating independent copies for the MLP layers which will be used
# by the multiple encoders/decoders.
for e in range(self.number_of_encoders):
Expand Down Expand Up @@ -281,6 +285,11 @@ def __init__(
self.final_layer = Linear(input_size=self.embed_dim_decoder, output_size=self.output_dim)
self.add_module("final_linear_layer", self.final_layer)

# Sending everything to the proper device
self.EncoderStage = self.EncoderStage.to(self.device)
self.DecoderStage = self.DecoderStage.to(self.device)
self.final_layer = self.final_layer.to(self.device)

@as_tensor
def forward(
self, input_data: Union[torch.Tensor, np.ndarray] = None
Expand Down
3 changes: 1 addition & 2 deletions simulai/templates/_pytorch_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _setup_activations(
activation_op = self._get_operation(operation=activation)

return (
(n_layers - 1) * [activation_op]
[self._get_operation(operation=activation) for i in range(n_layers - 1)]
+ [self._get_operation(operation=self.default_last_activation)],
(n_layers - 1) * [activation] + [self.default_last_activation],
)
Expand Down Expand Up @@ -212,7 +212,6 @@ def _setup_activations(

return activations_list, activation


else:
raise Exception(
"The activation format,"
Expand Down

0 comments on commit 8b67528

Please sign in to comment.