Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 20, 2024
2 parents 91c0305 + 8df1e5c commit ed4493d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
10 changes: 9 additions & 1 deletion sota-implementations/dreamer/dreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)

# mixed precision training
from torch.cuda.amp import GradScaler
from torch.amp import GradScaler
from torch.nn.utils import clip_grad_norm_
from torchrl._utils import logger as torchrl_logger, timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type
Expand Down Expand Up @@ -321,6 +321,14 @@ def compile_rssms(module):

t_collect_init = time.time()

test_env.close()
train_env.close()
collector.shutdown()

del test_env
del train_env
del collector


if __name__ == "__main__":
main()
17 changes: 14 additions & 3 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def __call__(cls, *args, **kwargs):
# will be called before or after the specs, batch size etc are set.
_ = instance.done_spec
_ = instance.reward_keys
_ = instance.action_keys
# _ = instance.action_keys
_ = instance.state_spec
if auto_reset:
from torchrl.envs.transforms.transforms import (
Expand Down Expand Up @@ -659,7 +659,7 @@ def action_keys(self) -> List[NestedKey]:
action_keys = self.__dict__.get("_action_keys")
if action_keys is not None:
return action_keys
keys = self.input_spec["full_action_spec"].keys(True, True)
keys = self.full_action_spec.keys(True, True)
if not len(keys):
raise AttributeError("Could not find action spec")
keys = sorted(keys, key=_repr_by_depth)
Expand Down Expand Up @@ -1072,7 +1072,18 @@ def full_reward_spec(self) -> Composite:
domain=continuous), device=None, shape=torch.Size([])), device=cpu, shape=torch.Size([]))
"""
return self.output_spec["full_reward_spec"]
try:
return self.output_spec["full_reward_spec"]
except KeyError:
# populate the "reward" entry
# this will be raised if there is not full_reward_spec (unlikely) or no reward_key
# Since output_spec is lazily populated with an empty composite spec for
# reward_spec, the second case is much more likely to occur.
self.reward_spec = Unbounded(
shape=(*self.batch_size, 1),
device=self.device,
)
return self.output_spec["full_reward_spec"]

@full_reward_spec.setter
def full_reward_spec(self, spec: Composite) -> None:
Expand Down

0 comments on commit ed4493d

Please sign in to comment.