@@ -686,6 +686,10 @@ def __init__(
686
686
policy = RandomPolicy (env .full_action_spec )
687
687
elif policy_factory is not None :
688
688
raise TypeError ("policy_factory cannot be used with policy argument." )
689
+ # If the underlying policy has a state_dict, we keep a reference to the policy and
690
+ # do all policy weight saving/loading through it
691
+ if hasattr (policy , "state_dict" ):
692
+ self ._policy_w_state_dict = policy
689
693
690
694
if trust_policy is None :
691
695
trust_policy = isinstance (policy , (RandomPolicy , CudaGraphModule ))
@@ -1686,8 +1690,8 @@ def state_dict(self) -> OrderedDict:
1686
1690
else :
1687
1691
env_state_dict = OrderedDict ()
1688
1692
1689
- if hasattr (self . policy , "state_dict " ):
1690
- policy_state_dict = self .policy .state_dict ()
1693
+ if hasattr (self , "_policy_w_state_dict " ):
1694
+ policy_state_dict = self ._policy_w_state_dict .state_dict ()
1691
1695
state_dict = OrderedDict (
1692
1696
policy_state_dict = policy_state_dict ,
1693
1697
env_state_dict = env_state_dict ,
@@ -1711,7 +1715,13 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:
1711
1715
if strict or "env_state_dict" in state_dict :
1712
1716
self .env .load_state_dict (state_dict ["env_state_dict" ], ** kwargs )
1713
1717
if strict or "policy_state_dict" in state_dict :
1714
- self .policy .load_state_dict (state_dict ["policy_state_dict" ], ** kwargs )
1718
+ if not hasattr (self , "_policy_w_state_dict" ):
1719
+ raise ValueError (
1720
+ "Underlying policy does not have state_dict to load policy_state_dict into."
1721
+ )
1722
+ self ._policy_w_state_dict .load_state_dict (
1723
+ state_dict ["policy_state_dict" ], ** kwargs
1724
+ )
1715
1725
self ._frames = state_dict ["frames" ]
1716
1726
self ._iter = state_dict ["iter" ]
1717
1727
0 commit comments