Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 15, 2024
1 parent 67f659c commit b0fdbce
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
VERBOSE,
)
from torchrl.data.tensor_specs import CompositeSpec
from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING
from torchrl.data.utils import contains_lazy_spec, DEVICE_TYPING
from torchrl.envs.common import _EnvPostInit, EnvBase
from torchrl.envs.env_creator import get_env_metadata

Expand Down Expand Up @@ -308,6 +308,13 @@ def __init__(
self.policy_proof = policy_proof
self.num_workers = num_workers
self.create_env_fn = create_env_fn

from torchrl.envs.env_creator import EnvCreator

for i, env_fun in enumerate(self.create_env_fn):
if not isinstance(env_fun, EnvCreator) and not isinstance(env_fun, EnvBase):
self.create_env_fn[i] = EnvCreator(env_fun)

self.create_env_kwargs = create_env_kwargs
self.pin_memory = pin_memory
if pin_memory:
Expand Down Expand Up @@ -1050,7 +1057,6 @@ class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta):
"""

def _start_workers(self) -> None:
from torchrl.envs.env_creator import EnvCreator

if self.num_threads is None:
self.num_threads = max(
Expand Down Expand Up @@ -1089,8 +1095,6 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
# No certainty which module multiprocessing_context is
parent_pipe, child_pipe = ctx.Pipe()
env_fun = self.create_env_fn[idx]
if not isinstance(env_fun, EnvCreator):
env_fun = CloudpickleWrapper(env_fun)
kwargs[idx].update(
{
"parent_pipe": parent_pipe,
Expand Down

0 comments on commit b0fdbce

Please sign in to comment.