Skip to content

Commit 0dc98d5

Browse files
felixy12 and feyu-bdai
authored and committed
[BugFix] Add reference to policy with state dict (#3043)
Co-authored-by: Felix Yu <[email protected]>
1 parent 227b5fc commit 0dc98d5

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

torchrl/collectors/collectors.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -686,6 +686,10 @@ def __init__(
686686
policy = RandomPolicy(env.full_action_spec)
687687
elif policy_factory is not None:
688688
raise TypeError("policy_factory cannot be used with policy argument.")
689+
# If the underlying policy has a state_dict, we keep a reference to the policy and
690+
# do all policy weight saving/loading through it
691+
if hasattr(policy, "state_dict"):
692+
self._policy_w_state_dict = policy
689693

690694
if trust_policy is None:
691695
trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule))
@@ -1686,8 +1690,8 @@ def state_dict(self) -> OrderedDict:
16861690
else:
16871691
env_state_dict = OrderedDict()
16881692

1689-
if hasattr(self.policy, "state_dict"):
1690-
policy_state_dict = self.policy.state_dict()
1693+
if hasattr(self, "_policy_w_state_dict"):
1694+
policy_state_dict = self._policy_w_state_dict.state_dict()
16911695
state_dict = OrderedDict(
16921696
policy_state_dict=policy_state_dict,
16931697
env_state_dict=env_state_dict,
@@ -1711,7 +1715,13 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:
17111715
if strict or "env_state_dict" in state_dict:
17121716
self.env.load_state_dict(state_dict["env_state_dict"], **kwargs)
17131717
if strict or "policy_state_dict" in state_dict:
1714-
self.policy.load_state_dict(state_dict["policy_state_dict"], **kwargs)
1718+
if not hasattr(self, "_policy_w_state_dict"):
1719+
raise ValueError(
1720+
"Underlying policy does not have state_dict to load policy_state_dict into."
1721+
)
1722+
self._policy_w_state_dict.load_state_dict(
1723+
state_dict["policy_state_dict"], **kwargs
1724+
)
17151725
self._frames = state_dict["frames"]
17161726
self._iter = state_dict["iter"]
17171727

0 commit comments

Comments
 (0)