[BUG] Function make_tensordict_primer Overlooks Batch-Locked Envs #2323
Comments
Not exactly sure what is going on in your code, but this works fine for me:

from torchrl.collectors import SyncDataCollector
from torchrl.envs import TransformedEnv, InitTracker
from torchrl.envs import GymEnv
from torchrl.modules import MLP, LSTMModule
from torch import nn
from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
assert env.base_env.batch_locked
lstm_module = LSTMModule(
    input_size=env.observation_spec["observation"].shape[-1],
    hidden_size=64,
    in_keys=["observation", "rs_h", "rs_c"],
    out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")],
)
mlp = MLP(num_cells=[64], out_features=1)
policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
policy(env.reset())
env = env.append_transform(lstm_module.make_tensordict_primer())
data_collector = SyncDataCollector(
    env,
    policy,
    frames_per_batch=10,
)
for data in data_collector:
    print(data)
    break

And the env is batch-locked. Can you give a runnable example, perhaps?
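As a side note, here is a minimal sketch (not from the thread) of the distinction between an env that is merely batch-locked and one that is also vectorized with a non-empty batch size; the SerialEnv wrapper is used purely for illustration:

from torchrl.envs import GymEnv, SerialEnv

single = GymEnv("Pendulum-v1")
# batch-locked but unbatched
print(single.batch_locked, single.batch_size)    # True, torch.Size([])

batched = SerialEnv(3, lambda: GymEnv("Pendulum-v1"))
# batch-locked and vectorized over 3 sub-envs
print(batched.batch_locked, batched.batch_size)  # True, torch.Size([3])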
Thanks for the reply! Sorry for misusing the word "batch-locked"; I meant vectorized environments with a batch size larger than one. My env is a derived env. A rollout of 16 steps and 64 envs looks like this:

original_env = MyEnv(
    cfg=cfg_dict["task"],
    device=cfg["sim_device"],
    graphics_device_id=cfg["graphics_device_id"],
    headless=cfg["headless"],
    force_render=cfg["force_render"],
)
check_env_specs(env)
target_num_steps = 200
t0 = time.time()
with torch.no_grad():
    rollout_data = env.rollout(target_num_steps)
t1 = time.time()
print(rollout_data)

2024-07-26 18:34:09,664 [torchrl][INFO] check_env_specs succeeded!
TensorDict(
fields={
action: Tensor(shape=torch.Size([64, 16, 4]), device=cuda:0, dtype=torch.float32, is_shared=True),
done: Tensor(shape=torch.Size([64, 16, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([64, 16, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
observation: TensorDict(
fields={
depth_image: Tensor(shape=torch.Size([64, 16, 1, 192, 256]), device=cuda:0, dtype=torch.float32, is_shared=True),
drone_state: Tensor(shape=torch.Size([64, 16, 18]), device=cuda:0, dtype=torch.float32, is_shared=True),
batch_size=torch.Size([64, 16]),
device=cuda:0,
is_shared=True),
reward: Tensor(shape=torch.Size([64, 16, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
terminated: Tensor(shape=torch.Size([64, 16, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
truncated: Tensor(shape=torch.Size([64, 16, 1]), device=cuda:0, dtype=torch.bool, is_shared=True)},
batch_size=torch.Size([64, 16]),
device=cuda:0,
is_shared=True),
observation: TensorDict(
fields={
depth_image: Tensor(shape=torch.Size([64, 16, 1, 192, 256]), device=cuda:0, dtype=torch.float32, is_shared=True),
drone_state: Tensor(shape=torch.Size([64, 16, 18]), device=cuda:0, dtype=torch.float32, is_shared=True),
batch_size=torch.Size([64, 16]),
device=cuda:0,
is_shared=True),
terminated: Tensor(shape=torch.Size([64, 16, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
truncated: Tensor(shape=torch.Size([64, 16, 1]), device=cuda:0, dtype=torch.bool, is_shared=True)},
batch_size=torch.Size([64, 16]),
device=cuda:0,
is_shared=True)

I also tried to modify your code with ParallelEnv:

from torchrl.collectors import SyncDataCollector
from torchrl.envs import TransformedEnv, InitTracker
from torchrl.envs import GymEnv
from torchrl.modules import MLP, LSTMModule
import torch
from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
from torchrl.envs import ParallelEnv
def env_make():
    return GymEnv("Pendulum-v1")
env = TransformedEnv(ParallelEnv(3, env_make), InitTracker())
assert env.base_env.batch_locked
lstm_module = LSTMModule(
    input_size=env.observation_spec["observation"].shape[-1],
    hidden_size=64,
    in_keys=["observation", "rs_h", "rs_c"],
    out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")],
)
mlp = MLP(num_cells=[64], out_features=1)
policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
policy(env.reset())
env = env.append_transform(lstm_module.make_tensordict_primer())
data_collector = SyncDataCollector(env, policy, frames_per_batch=10)
for data in data_collector:
    print(data)
    break

But it gave me errors when executing it:

...
File "/home/lyq/Developer/isaacgym_workspace/scratch/test_td_primer.py", line 25, in <module>
policy(env.reset())
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/common.py", line 2120, in reset
tensordict_reset = self._reset(tensordict, **kwargs)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 809, in _reset
tensordict_reset = self.base_env._reset(tensordict, **kwargs)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/batched_envs.py", line 56, in decorated_fun
self._start_workers()
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/batched_envs.py", line 1275, in _start_workers
process.start()
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/process.py", line 121, in start
self._popen = self._Popen(self)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/context.py", line 224, in _Popen
return _default_context.get_context().Process._Popen(process_obj)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/context.py", line 284, in _Popen
return Popen(process_obj)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 32, in __init__
super().__init__(process_obj)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/popen_fork.py", line 19, in __init__
self._launch(process_obj)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 42, in _launch
prep_data = spawn.get_preparation_data(process_obj._name)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/spawn.py", line 154, in get_preparation_data
_check_not_importing_main()
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/spawn.py", line 134, in _check_not_importing_main
raise RuntimeError('''
RuntimeError:
An attempt has been made to start a new process before the
current process has finished its bootstrapping phase.
This probably means that you are not using fork to start your
child processes and you have forgotten to use the proper idiom
in the main module:
if __name__ == '__main__':
freeze_support()
...
The "freeze_support()" line can be omitted if the program
is not going to be frozen to produce an executable.

Maybe I am not using parallel envs correctly... So I am afraid I can't give you a running example, but I hope these tests can provide more info for the issue.
Your rollouts look OK to me. The error you're seeing with ParallelEnv should be solved if you put it in an if __name__ == "__main__": block, like this:

from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv
from torchrl.envs import TransformedEnv, InitTracker, ParallelEnv, SerialEnv
from torchrl.modules import MLP, GRUModule
def make_env():
    return TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())

if __name__ == "__main__":
    gru_module = GRUModule(
        input_size=make_env().observation_spec["observation"].shape[-1],
        hidden_size=64,
        in_keys=["observation", "rs"],
        out_keys=["intermediate", ("next", "rs")],
    )
    mlp = MLP(num_cells=[64], out_features=1)
    policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
    primer = gru_module.make_tensordict_primer()
    env = ParallelEnv(
        2,
        lambda primer=primer: SerialEnv(3, make_env).append_transform(primer.clone()),
    )
    reset = env.reset()
    print("reset", reset)
    policy(reset)
    print("reset after policy", reset)
    data_collector = SyncDataCollector(
        env,
        policy,
        frames_per_batch=10,
    )
    for data in data_collector:
        print("data from rollout", data)
        break

Since I can't repro, I'm closing this, but if you still encounter an error feel free to re-open.
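As a side note, the same pattern should carry over to LSTMModule, which the original report uses and which carries two recurrent states instead of one. A hedged sketch, assuming the keys from the earlier examples in this thread:

from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
from torchrl.envs import GymEnv, InitTracker, ParallelEnv, SerialEnv, TransformedEnv
from torchrl.modules import MLP, LSTMModule

def make_env():
    return TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())

if __name__ == "__main__":
    lstm_module = LSTMModule(
        input_size=make_env().observation_spec["observation"].shape[-1],
        hidden_size=64,
        in_keys=["observation", "rs_h", "rs_c"],
        out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")],
    )
    policy = Seq(
        lstm_module,
        Mod(MLP(num_cells=[64], out_features=1), in_keys=["intermediate"], out_keys=["action"]),
    )
    primer = lstm_module.make_tensordict_primer()
    # Clone the primer into each inner SerialEnv so its specs are matched
    # against the batch size each worker actually sees.
    env = ParallelEnv(
        2,
        lambda primer=primer: SerialEnv(3, make_env).append_transform(primer.clone()),
    )
    print(env.rollout(5, policy))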
Thx for the example. But I get errors and cannot complete the rollout. Is it because I am using version 0.4?

11:59:40 |mujoco|lyq@xpg scratch → python test_td_primer.py
/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/common.py:2989: DeprecationWarning: Your wrapper was not given a device. Currently, this value will default to 'cpu'. From v0.5 it will default to `None`. With a device of None, no device casting is performed and the resulting tensordicts are deviceless. Please set your device accordingly.
warnings.warn(
/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torch/nn/modules/lazy.py:181: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
warnings.warn('Lazy modules are a new feature under heavy development '
/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/common.py:2989: DeprecationWarning: Your wrapper was not given a device. Currently, this value will default to 'cpu'. From v0.5 it will default to `None`. With a device of None, no device casting is performed and the resulting tensordicts are deviceless. Please set your device accordingly.
warnings.warn(
Process _ProcessNoWarn-2:
Traceback (most recent call last):
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
self.run()
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/_utils.py", line 668, in run
return mp.Process.run(self, *args, **kwargs)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/batched_envs.py", line 1765, in _run_worker_pipe_shared_mem
cur_td = env.reset(
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/common.py", line 2120, in reset
tensordict_reset = self._reset(tensordict, **kwargs)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 814, in _reset
tensordict_reset = self.transform._reset(tensordict, tensordict_reset)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 4723, in _reset
expand_as_right(_reset, value), value, prev_val
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/tensordict/utils.py", line 331, in expand_as_right
raise RuntimeError(
RuntimeError: tensor shape is incompatible with dest shape, got: tensor.shape=torch.Size([3]), dest=torch.Size([1, 64])
Process _ProcessNoWarn-1:
Traceback (most recent call last):
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
self.run()
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/_utils.py", line 668, in run
return mp.Process.run(self, *args, **kwargs)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/batched_envs.py", line 1765, in _run_worker_pipe_shared_mem
cur_td = env.reset(
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/common.py", line 2120, in reset
tensordict_reset = self._reset(tensordict, **kwargs)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 814, in _reset
tensordict_reset = self.transform._reset(tensordict, tensordict_reset)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 4723, in _reset
expand_as_right(_reset, value), value, prev_val
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/tensordict/utils.py", line 331, in expand_as_right
raise RuntimeError(
RuntimeError: tensor shape is incompatible with dest shape, got: tensor.shape=torch.Size([3]), dest=torch.Size([1, 64])
reset TensorDict(
fields={
done: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
is_init: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([2, 3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
rs: Tensor(shape=torch.Size([2, 3, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([2, 3]),
device=cpu,
is_shared=False)
reset after policy TensorDict(
fields={
action: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
done: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
intermediate: Tensor(shape=torch.Size([2, 3, 64]), device=cpu, dtype=torch.float32, is_shared=False),
is_init: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: TensorDict(
fields={
rs: Tensor(shape=torch.Size([2, 3, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 3]),
device=cpu,
is_shared=False),
observation: Tensor(shape=torch.Size([2, 3, 3]), device=cpu, dtype=torch.float32, is_shared=False),
rs: Tensor(shape=torch.Size([2, 3, 1, 64]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([2, 3]),
device=cpu,
is_shared=False)
/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/collectors/collectors.py:618: UserWarning: frames_per_batch (10) is not exactly divisible by the number of batched environments (6), this results in more frames_per_batch per iteration that requested (12).To silence this message, set the environment variable RL_WARNINGS to False.
warnings.warn(
Traceback (most recent call last):
File "test_td_primer.py", line 63, in <module>
data_collector = SyncDataCollector(
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/collectors/collectors.py", line 633, in __init__
self._shuttle = self.env.reset()
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/common.py", line 2120, in reset
tensordict_reset = self._reset(tensordict, **kwargs)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/envs/batched_envs.py", line 59, in decorated_fun
_check_for_faulty_process(self._workers)
File "/home/lyq/mambaforge/envs/mujoco/lib/python3.8/site-packages/torchrl/_utils.py", line 124, in _check_for_faulty_process
raise RuntimeError(
RuntimeError: At least one process failed. Check for more infos in the log.
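For context, the failing call in the traceback is expand_as_right from tensordict.utils: the per-worker reset mask has shape [3] (one flag per sub-env in SerialEnv(3, ...)), while the primer's default value keeps the unbatched spec shape (num_layers, hidden_size) = [1, 64], so the leading dimensions cannot be matched. A minimal sketch, assuming the shapes reported in the log above:

import torch
from tensordict.utils import expand_as_right

reset_mask = torch.ones(3, dtype=torch.bool)  # one reset flag per sub-env
hidden_default = torch.zeros(1, 64)           # primer value without batch dims

try:
    # expand_as_right requires hidden_default.shape[:1] == reset_mask.shape, i.e. (1,) == (3,)
    expand_as_right(reset_mask, hidden_default)
except RuntimeError as err:
    print(err)  # "tensor shape is incompatible with dest shape, ..."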
Could be! I'll make the 0.5 release soon (sometime this week), let's see if it fixes it (you can also use the nightlies to check!)
In v0.5,
Describe the bug
Appending the TensorDictPrimer transform created from LSTMModule.make_tensordict_primer triggers a dimension error.

To Reproduce
Expected behavior
Should run without error, or, when appending the TensorDictPrimer transform, its shapes should be checked.
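A hedged sketch of the kind of check this suggests (a hypothetical helper, not part of torchrl): verify that each primer spec starts with the env's batch dims before the transform is appended.

def check_primer_shapes(env, primer_specs: dict) -> None:
    """Hypothetical helper: raise if a primer spec does not start with the env's batch dims."""
    batch = tuple(env.batch_size)
    for key, spec in primer_specs.items():
        lead = tuple(spec.shape[: len(batch)])
        if env.batch_locked and lead != batch:
            raise ValueError(
                f"Primer entry {key!r} has shape {tuple(spec.shape)}, "
                f"which does not start with the env batch size {batch}."
            )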
System info
torchrl==0.4
Additional context
There is a similar issue on: #1493
Reason and Possible fixes
Function make_tensordict_primer overlooks batch-locked envs, as its source code shows. In this case, users can manually add the transform with proper shapes.
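A hedged sketch of that manual workaround; the keys, num_layers=1, hidden_size=64, and the batch size are assumptions taken from the examples above, and the helper name is made up for illustration:

from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs.transforms import TensorDictPrimer

def batched_lstm_primer(batch_size, num_layers=1, hidden_size=64):
    """Hypothetical helper: a TensorDictPrimer whose specs include the env's batch dims."""
    shape = (*batch_size, num_layers, hidden_size)
    return TensorDictPrimer(
        rs_h=UnboundedContinuousTensorSpec(shape=shape),
        rs_c=UnboundedContinuousTensorSpec(shape=shape),
    )

# Usage, assuming `env` is the vectorized env from the thread:
# env = env.append_transform(batched_lstm_primer(env.batch_size))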
Checklist