Skip to content

Commit

Permalink
Merge pull request #184 from IBM/fix/wavelet_activation
Browse files Browse the repository at this point in the history
The activation parameters were not being properly placed on the target device.
  • Loading branch information
Joao-L-S-Almeida authored Feb 12, 2024
2 parents 948cab6 + 1843604 commit 8c8b499
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion simulai/models/_pytorch_models/_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from typing import Union, Tuple

from simulai.templates import NetworkTemplate, as_tensor
from simulai.templates import NetworkTemplate, as_tensor, guarantee_device
from simulai.regression import DenseNetwork, Linear


Expand Down
2 changes: 1 addition & 1 deletion simulai/regression/_pytorch/_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def __init__(
"""

super(DenseNetwork, self).__init__()
super(DenseNetwork, self).__init__(**kwargs)

assert layers_units, "Please, set a list of units for each layer"

Expand Down
7 changes: 4 additions & 3 deletions simulai/templates/_pytorch_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ def __init__(self, name: str = None, devices: str = None) -> None:

self.shapes_dict = None
self.device_type = devices
self.device = self._set_device(devices=devices)

if self.device_type:
if self.device_type != "cpu":
self.to_wrap = self._to_explicit_device
else:
self.to_wrap = self._to_bypass
Expand Down Expand Up @@ -148,7 +149,7 @@ def _get_operation(
if torch.nn.Module in res_.__mro__:
res = res_
print(f"Module {operation} found in {engine}.")
return res()
return res(**kwargs)
else:
print(f"Module {operation} not found in {engine}.")
else:
Expand All @@ -175,7 +176,7 @@ def _setup_activations(
if isinstance(activation_op, simulact.TrainableActivation):

activations_list = [self._get_operation(operation=activation,
is_activation=True, device=self.device_type)
is_activation=True, device=self.device)
for i in range(n_layers - 1)]

else:
Expand Down

0 comments on commit 8c8b499

Please sign in to comment.