Skip to content

Commit

Permalink
Calling to() is recursive on the submodules of an nn.Module
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Jan 15, 2024
1 parent 650b509 commit cc20b74
Showing 1 changed file with 0 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

self.model.to(device)
# TODO: this should be captured by the model.to(device) statement
self.model.transformer.to(device)
logging.info(f"Training model on device: {device}. {next(self.model.parameters()).device}")

if os.path.exists(save_name):
Expand Down

0 comments on commit cc20b74

Please sign in to comment.