Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 20, 2024
2 parents 40390f5 + bde9f05 commit fafa7bd
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4996,8 +4996,14 @@ def __init__(
)
kwargs = primers
if not isinstance(kwargs, Composite):
kwargs = Composite(**kwargs)
self.primers = kwargs
shape = kwargs.pop("shape", None)
device = kwargs.pop("device", None)
if "batch_size" in kwargs.keys():
extra_kwargs = {"batch_size": kwargs.pop("batch_size")}
else:
extra_kwargs = {}
primers = Composite(kwargs, device=device, shape=shape, **extra_kwargs)
self.primers = primers
self.expand_specs = expand_specs

if random and default_value:
Expand Down

0 comments on commit fafa7bd

Please sign in to comment.