Skip to content

Commit

Permalink
Prevent torch FutureWarning torch.load with weights_only=False wh…
Browse files Browse the repository at this point in the history
…en loading checkpoints
  • Loading branch information
Toni-SM committed Sep 10, 2024
1 parent 6851468 commit 8d84bbf
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 4 deletions.
6 changes: 5 additions & 1 deletion skrl/agents/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import gym
import gymnasium
from packaging import version

import numpy as np
import torch
Expand Down Expand Up @@ -374,7 +375,10 @@ def load(self, path: str) -> None:
:param path: Path to load the model from
:type path: str
"""
modules = torch.load(path, map_location=self.device)
if version.parse(torch.__version__) >= version.parse("1.13"):
modules = torch.load(path, map_location=self.device, weights_only=False) # prevent torch:FutureWarning
else:
modules = torch.load(path, map_location=self.device)
if type(modules) is dict:
for name, data in modules.items():
module = self.checkpoint_modules.get(name, None)
Expand Down
7 changes: 6 additions & 1 deletion skrl/models/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import collections
import gym
import gymnasium
from packaging import version

import numpy as np
import torch
Expand Down Expand Up @@ -488,7 +489,11 @@ def load(self, path: str) -> None:
>>> model = Model(observation_space, action_space, device="cuda:1")
>>> model.load("model.pt")
"""
self.load_state_dict(torch.load(path, map_location=self.device))
if version.parse(torch.__version__) >= version.parse("1.13"):
state_dict = torch.load(path, map_location=self.device, weights_only=False) # prevent torch:FutureWarning
else:
state_dict = torch.load(path, map_location=self.device)
self.load_state_dict(state_dict)
self.eval()

def migrate(self,
Expand Down
7 changes: 6 additions & 1 deletion skrl/models/torch/tabular.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any, Mapping, Optional, Sequence, Tuple, Union

from packaging import version

import torch

from skrl.models.torch import Model
Expand Down Expand Up @@ -196,7 +198,10 @@ def load(self, path: str) -> None:
>>> model = Model(observation_space, action_space, device="cuda:1")
>>> model.load("model.pt")
"""
tensors = torch.load(path)
if version.parse(torch.__version__) >= version.parse("1.13"):
tensors = torch.load(path, weights_only=False) # prevent torch:FutureWarning
else:
tensors = torch.load(path)
for name, tensor in tensors.items():
if hasattr(self, name) and isinstance(getattr(self, name), torch.Tensor):
_tensor = getattr(self, name)
Expand Down
6 changes: 5 additions & 1 deletion skrl/multi_agents/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import gym
import gymnasium
from packaging import version

import numpy as np
import torch
Expand Down Expand Up @@ -393,7 +394,10 @@ def load(self, path: str) -> None:
:param path: Path to load the model from
:type path: str
"""
modules = torch.load(path, map_location=self.device)
if version.parse(torch.__version__) >= version.parse("1.13"):
modules = torch.load(path, map_location=self.device, weights_only=False) # prevent torch:FutureWarning
else:
modules = torch.load(path, map_location=self.device)
if type(modules) is dict:
for uid in self.possible_agents:
if uid not in modules:
Expand Down

0 comments on commit 8d84bbf

Please sign in to comment.