Skip to content

Commit

Permalink
Devices defined at the Wavelet activation constructor
Browse files Browse the repository at this point in the history
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
  • Loading branch information
Joao-L-S-Almeida committed Feb 9, 2024
1 parent a0b8ab5 commit 58c31bb
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions simulai/templates/_pytorch_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _get_from_guest(self, **kwargs) -> None:

# Getting up activation if it exists
def _get_operation(
self, operation: str = None, is_activation: bool = True
self, operation: str = None, is_activation: bool = True, **kwargs,
) -> callable:
mod_items = dir(self.engine)
mod_items_lower = [item.lower() for item in mod_items]
Expand All @@ -136,7 +136,7 @@ def _get_operation(
operation_class = getattr(self.engine, operation_name)

if is_activation is True:
return operation_class()
return operation_class(**kwargs)
else:
return operation_class
else:
Expand Down Expand Up @@ -175,13 +175,9 @@ def _setup_activations(
if isinstance(activation_op, simulact.TrainableActivation):

activations_list = [self._get_operation(operation=activation,
is_activation=True)
is_activation=True, device=self.device_type)
for i in range(n_layers - 1)]

for aa, act in enumerate(activations_list):
act.setup(device=self.device_type)
activations_list[aa] = act

else:
activations_list = [self._get_operation(operation=activation)
for i in range(n_layers - 1)]
Expand Down

0 comments on commit 58c31bb

Please sign in to comment.