[BUG] Problems with BatchedEnv on accelerated device with single envs on cpu #1864
Comments
To reproduce the bug on CUDA:
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
EnvCreator,
ExplorationType,
StepCounter,
TransformedEnv,
ParallelEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor
max_step = 10
n_env = 4
env_id = "MountainCar-v0"
device = "cuda:0"
def build_cpu_single_env():
env = GymEnv(env_id, device="cpu")
env = TransformedEnv(env)
env.append_transform(StepCounter(max_steps=max_step, truncated_key="foo"))
return env
def build_actor(env):
return ProbabilisticActor(
module=TensorDictModule(
nn.LazyLinear(env.action_spec.space.n),
in_keys=["observation"],
out_keys=["logits"],
),
spec=env.action_spec,
distribution_class=OneHotCategorical,
in_keys=["logits"],
default_interaction_type=ExplorationType.RANDOM,
)
if __name__ == "__main__":
env = ParallelEnv(n_env, EnvCreator(lambda: build_cpu_single_env()), device=device)
env = TransformedEnv(env)
policy_module = build_actor(env)
policy_module.to(device)
policy_module(env.reset())
for i in range(10):
batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
max_step_count = batches["next", "step_count"].max().item()
if max_step_count > max_step:
print("Problem!")
print(max_step_count)
break
else:
print("No problem!") |
I can reproduce the initial example only if the env is on "cpu", so it's likely just a problem of casting from device to device in the serial env. |
Nice, thanks! Indeed, it's probably device casting gone wrong somewhere, as MPS literally crashed with a segfault. |
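A quick sanity check of the "casting gone wrong" theory is to reinterpret the corrupted step-count values seen in this thread as raw float32 bits and see whether they decode to plausible floating-point payloads (a diagnostic sketch, not a confirmed root-cause analysis):
import struct
def as_float32(bits):
    # Reinterpret the low 32 bits of an integer as an IEEE-754 float32.
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]
print(as_float32(1065353217))  # ~1.0   (the CUDA value quoted in the issue description)
print(as_float32(3201372667))  # ~-0.41 (the CUDA value in the log below; within MountainCar's position range)
Both decode to plausible float values rather than counters, which is consistent with bytes from a float buffer being read back as int64 somewhere during the device copy, though this is only a hint.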
Can you have a go at #1866 for cpu envs? For me it works with sub-envs on cpu and cuda (even with 100 outer steps). |
Also, I ran the second example with 10K outer iterations but could not reproduce the issue (on the branch of the PR, though I did not change much from main), so I'm not sure how to address this.
|
VERBOSE=1 python -c """import tqdm
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
EnvCreator,
ExplorationType,
StepCounter,
TransformedEnv,
ParallelEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor
max_step = 10
n_env = 4
env_id = 'MountainCar-v0'
device = 'cuda:0'
...
/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
logger.warn(
/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
logger.warn(
2024-02-02 13:10:48,627 [torchrl][INFO] resetting implement_for
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
warnings.warn('Lazy modules are a new feature under heavy development '
2024-02-02 13:10:48,872 [torchrl][INFO] initiating worker 0
2024-02-02 13:10:48,965 [torchrl][INFO] initiating worker 1
2024-02-02 13:10:48,967 [torchrl][INFO] initiating worker 2
2024-02-02 13:10:48,969 [torchrl][INFO] initiating worker 3
2024-02-02 13:10:51,116 [torchrl][INFO] resetting implement_for
2024-02-02 13:10:51,142 [torchrl][INFO] resetting implement_for
2024-02-02 13:10:51,192 [torchrl][INFO] resetting implement_for
2024-02-02 13:10:51,196 [torchrl][INFO] resetting implement_for
0%| | 0/10000 [00:00<?, ?it/s]Problem!
3201372667
tensor([[[ 1],
[ 2],
[ 3],
[ 4],
[ 5],
[ 6],
[ 7],
[ 8],
[ 9],
[ 10],
[ 1],
[ 1],
[ 1]],
[[ 1],
[ 2],
[ 3],
[ 4],
[ 5],
[ 6],
[ 7],
[ 8],
[ 9],
[ 10],
[ 1],
[ 1],
[3201372667]],
[[ 1],
[ 2],
[ 3],
[ 4],
[ 5],
[ 6],
[ 7],
[ 8],
[ 9],
[ 10],
[ 1],
[ 1],
[ 1]],
[[ 1],
[ 2],
[ 3],
[ 4],
[ 5],
[ 6],
[ 7],
[ 8],
[ 9],
[ 10],
[ 1],
[ 1],
[ 1]]], device='cuda:0')
0%|
❯ python
Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torchrl, tensordict, torch, numpy, sys
>>> print(torch.__version__, tensordict.__version__, torchrl.__version__, numpy.__version__, sys.version, sys.platform)
2.2.0a0+81ea7a4 0.4.0+eaef29e 0.4.0+01a2216 1.24.4 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] linux
I will try with the #1866 branch and on another cluster. |
I will the |
The problem is now different with:
VERBOSE=1 python -c """import tqdm
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
EnvCreator,
ExplorationType,
StepCounter,
TransformedEnv,
ParallelEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor
max_step = 10
n_env = 4
env_id = 'MountainCar-v0'
device = 'cuda:0'
def build_cpu_single_env():
env = GymEnv(env_id, device='cpu')
env = TransformedEnv(env)
env.append_transform(StepCounter(max_steps=max_step, truncated_key='foo'))
return env
def build_actor(env):
return ProbabilisticActor(
module=TensorDictModule(
nn.LazyLinear(env.action_spec.space.n),
in_keys=['observation'],
out_keys=['logits'],
),
spec=env.action_spec,
distribution_class=OneHotCategorical,
in_keys=['logits'],
default_interaction_type=ExplorationType.RANDOM,
)
if __name__ == '__main__':
env = ParallelEnv(n_env, EnvCreator(lambda: build_cpu_single_env()), device=device)
env = TransformedEnv(env)
policy_module = build_actor(env)
policy_module.to(device)
policy_module(env.reset())
for i in tqdm.tqdm(range(10)):
batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
max_step_count = batches['next', 'step_count'].max().item()
if max_step_count > max_step:
print('Problem 1!')
print(max_step_count)
print(batches['next', 'step_count'])
break
elif max_step_count < max_step:
print('Problem 2!')
print(max_step_count)
print(batches['next', 'step_count'])
break
else:
print('No problem!')
"""
/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
logger.warn(
/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
logger.warn(
2024-02-02 13:44:03,727 [torchrl][INFO] resetting implement_for
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
warnings.warn('Lazy modules are a new feature under heavy development '
2024-02-02 13:44:03,981 [torchrl][INFO] initiating worker 0
2024-02-02 13:44:04,041 [torchrl][INFO] initiating worker 1
2024-02-02 13:44:04,043 [torchrl][INFO] initiating worker 2
2024-02-02 13:44:04,045 [torchrl][INFO] initiating worker 3
2024-02-02 13:44:06,233 [torchrl][INFO] resetting implement_for
2024-02-02 13:44:06,241 [torchrl][INFO] resetting implement_for
2024-02-02 13:44:06,255 [torchrl][INFO] resetting implement_for
2024-02-02 13:44:06,273 [torchrl][INFO] resetting implement_for
10%|████████████████████▉ | 1/10 [00:00<00:07, 1.23it/s]Problem 2!
1
tensor([[[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1]],
[[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1]],
[[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1]],
[[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1],
[1]]], device='cuda:0')
10%|████████████████████▉
❯ python
Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torchrl, tensordict, torch, numpy, sys
>>> print(torch.__version__, tensordict.__version__, torchrl.__version__, numpy.__version__, sys.version, sys.platform)
2.2.0a0+81ea7a4 0.4.0+eaef29e 0.4.0+1ea3c74 1.24.4 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] linux
I will try on a different cluster. |
Happens on two different clusters with different CPUs and GPUs (same Docker image though, the NVIDIA NGC PyTorch). |
On MPS it's no longer a segfault but the original arbitrary-number bug:
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
EnvCreator,
ExplorationType,
StepCounter,
TransformedEnv,
SerialEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor
max_step = 10
n_env = 4
env_id = "CartPole-v1"
device = "mps"
def build_cpu_single_env():
env = GymEnv(env_id, device="cpu")
env = TransformedEnv(env)
env.append_transform(StepCounter(max_steps=max_step))
return env
def build_actor(env):
return ProbabilisticActor(
module=TensorDictModule(
nn.LazyLinear(env.action_spec.space.n),
in_keys=["observation"],
out_keys=["logits"],
),
spec=env.action_spec,
distribution_class=OneHotCategorical,
in_keys=["logits"],
default_interaction_type=ExplorationType.RANDOM,
)
if __name__ == "__main__":
env = SerialEnv(n_env, EnvCreator(lambda: build_cpu_single_env()), device=device)
policy_module = build_actor(env)
policy_module.to(device)
policy_module(env.reset())
for i in range(10):
batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
max_step_count = batches["next", "step_count"].max().item()
print(max_step_count)
print(batches["next", "step_count"])
if max_step_count > max_step:
print("Problem!")
print(max_step_count)
break
else:
print("No problem!")
gives:
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/torch/nn/modules/lazy.py:181: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
warnings.warn('Lazy modules are a new feature under heavy development '
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
logger.warn(
5157210208
tensor([[[ 6],
[ 6],
[ 1],
[ 6],
[ 0],
[ 6],
[ 6],
[ 6],
[ 6],
[ 6],
[ 6],
[ 1],
[ 6]],
[[ 1],
[ 1],
[ 1],
[ 1],
[5157210208],
[ 1],
[ 2],
[ 1],
[ 1],
[ 2],
[ 1],
[ 2],
[ 1]],
[[ 1],
[ 1],
[ 1],
[ 1],
[ 6],
[ 1],
[ 2],
[ 1],
[ 1],
[ 2],
[ 1],
[ 2],
[ 1]],
[[ 1],
[ 1],
[ 1],
[ 1],
[ 1],
[ 1],
[ 2],
[ 1],
[ 1],
[ 2],
[ 1],
[ 2],
[ 1]]], device='mps:0')
Problem!
5157210208 |
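The MPS value is less conclusive: 5157210208 does not even fit in 32 bits, and its hex form sits in the same user-space address range as the heap pointers in the malloc aborts reported further down, which would hint at memory corruption rather than a simple dtype reinterpretation (again just an observation, not a confirmed diagnosis):
print(hex(5157210208))  # 0x13364c860 -- same 0x13... range as ptr=0x139ced580 in the malloc error below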
MPS still gives segfault for |
I think it's solved now (for CUDA, on serial and parallel envs, on the bugfix PR). |
Does this need a specific branch on tensordict? |
Yeah, sorry, I'm patching TensorDict; let me quickly revert the latest changes, which should be part of another PR. |
I changed it, and tests seem to be passing. If they all do, I'll do a final run of your examples and check the status on MPS. If it all runs smoothly I will consider the PR as good unless you wish to do a proper review of it. |
I'll poke a bit now and give my feedback soon. So I should test with this branch on TorchRL and main on Tensordict? |
Yes tensordict main is up to date |
All good for CUDA! Awesome! (tested some scripts but didn't check the PR code)
❯ python
Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torchrl, tensordict, torch, numpy, sys
>>> print(torch.__version__, tensordict.__version__, torchrl.__version__, numpy.__version__, sys.version, sys.platform)
2.2.0a0+81ea7a4 0.4.0+99705db 0.4.0+1f485e9 1.24.4 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] linux
(Played with variations of things like here: https://github.com/skandermoalla/TorchRL/tree/34c8abf19fd5a5177a2d5eadd5a5b1f57d51ab6c/tests) |
Testing for MPS. |
Not yet for MPS. For SerialEnv (https://github.com/skandermoalla/TorchRL/blob/34c8abf19fd5a5177a2d5eadd5a5b1f57d51ab6c/tests/issue_env_device_serial.py) I have different errors that appear arbitrarily:
python(22437,0x1d9879300) malloc: tiny_free_list_remove_ptr: Internal invariant broken (next ptr of prev): ptr=0x139ced580, prev_next=0x0
python(22437,0x1d9879300) malloc: *** set a breakpoint in malloc_error_break to debug
[1] 22437 abort python tests/issue_env_device_serial.py
Traceback (most recent call last):
File "/Users/skander/projects/open-source/TorchRL/tests/issue_env_device_serial.py", line 47, in <module>
batches = env.rollout((2 * max_step + 3), policy=policy_module, break_when_any_done=False)
File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2395, in rollout
tensordicts = self._rollout_nonstop(**kwargs)
File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2484, in _rollout_nonstop
tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2554, in step_and_maybe_reset
tensordict_ = self.reset(tensordict_)
File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2056, in reset
tensordict_reset = self._reset(tensordict, **kwargs)
File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 58, in decorated_fun
return fun(self, *args, **kwargs)
File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 781, in _reset
_td = _env.reset(tensordict=tensordict_, **kwargs)
File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2071, in reset
return self._reset_proc_data(tensordict, tensordict_reset)
File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/transforms/transforms.py", line 795, in _reset_proc_data
self._reset_check_done(tensordict, tensordict_reset)
File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2103, in _reset_check_done
raise RuntimeError(
RuntimeError: Env done entry 'truncated' was (partially) True after reset on specified '_reset' dimensions. This is not allowed.
and
python(22700,0x16dcdb000) malloc: Incorrect checksum for freed object 0x13b396f08: probably modified after being freed.
Corrupt value: 0xbc99eb7cbad79154
python(22700,0x16dcdb000) malloc: *** set a breakpoint in malloc_error_break to debug
[1] 22700 abort python tests/issue_env_device_serial.py
For ParallelEnv (https://github.com/skandermoalla/TorchRL/blob/34c8abf19fd5a5177a2d5eadd5a5b1f57d51ab6c/tests/issue_env_device_parallel.py) I also have arbitrary errors:
Traceback (most recent call last):
File "/Users/skander/projects/open-source/TorchRL/tests/issue_env_device_parallel.py", line 45, in <module>
policy_module(env.reset())
File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2071, in reset
return self._reset_proc_data(tensordict, tensordict_reset)
File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/transforms/transforms.py", line 795, in _reset_proc_data
self._reset_check_done(tensordict, tensordict_reset)
File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2123, in _reset_check_done
raise RuntimeError(
RuntimeError: The done entry 'truncated' was (partially) True after a call to reset() in env TransformedEnv(
env=ParallelEnv(
env=TransformedEnv(
env=GymEnv(env=MountainCar-v0, batch_size=torch.Size([]), device=cpu),
transform=Compose(
StepCounter(keys=[]))),
batch_size=torch.Size([4])),
transform=Compose(
)).
and
Traceback (most recent call last):
File "/Users/skander/projects/open-source/TorchRL/tests/issue_env_device_parallel.py", line 47, in <module>
batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2395, in rollout
tensordicts = self._rollout_nonstop(**kwargs)
File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2484, in _rollout_nonstop
tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2554, in step_and_maybe_reset
tensordict_ = self.reset(tensordict_)
File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2056, in reset
tensordict_reset = self._reset(tensordict, **kwargs)
File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/transforms/transforms.py", line 785, in _reset
tensordict_reset = self.base_env._reset(tensordict=tensordict, **kwargs)
File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 58, in decorated_fun
return fun(self, *args, **kwargs)
File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 1298, in _reset
channel.send(out)
File "/Users/skander/mambaforge/envs/torchrl/lib/python3.10/multiprocessing/connection.py", line 206, in send
self._send_bytes(_ForkingPickler.dumps(obj))
File "/Users/skander/mambaforge/envs/torchrl/lib/python3.10/multiprocessing/reduction.py", line 51, in dumps
cls(buf, protocol).dump(obj)
File "/Users/skander/mambaforge/envs/torchrl/lib/python3.10/site-packages/torch/multiprocessing/reductions.py", line 557, in reduce_storage
metadata = storage._share_filename_cpu_()
File "/Users/skander/mambaforge/envs/torchrl/lib/python3.10/site-packages/torch/storage.py", line 294, in wrapper
return fn(self, *args, **kwargs)
File "/Users/skander/mambaforge/envs/torchrl/lib/python3.10/site-packages/torch/storage.py", line 368, in _share_filename_cpu_
return super()._share_filename_cpu_(*args, **kwargs)
RuntimeError: _share_filename_: only available on CPU |
Those seem to be different issues than the CUDA ones, so I think we should go ahead with the PR and make sure these things are working ok with MPS separately! |
Reopening to keep track of progress with MPS |
If I avoid updating slices (see this bug), which happens if you create a single copy of the env, I have no issue with the following code:
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
EnvCreator,
ExplorationType,
StepCounter,
TransformedEnv,
SerialEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor
max_step = 10
n_env = 4
env_id = "CartPole-v1"
device = "mps"
def build_cpu_single_env():
env = GymEnv(env_id, device="cpu")
env = TransformedEnv(env)
env.append_transform(StepCounter(max_steps=max_step))
return env
def build_actor(env):
return ProbabilisticActor(
module=TensorDictModule(
nn.LazyLinear(env.action_spec.space.n),
in_keys=["observation"],
out_keys=["logits"],
),
spec=env.action_spec,
distribution_class=OneHotCategorical,
in_keys=["logits"],
default_interaction_type=ExplorationType.RANDOM,
)
if __name__ == "__main__":
env = SerialEnv(n_env, [EnvCreator(build_cpu_single_env) for _ in range(n_env)], device=device)
policy_module = build_actor(env)
policy_module.to(device)
policy_module(env.reset())
for i in range(10):
batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
max_step_count = batches["next", "step_count"].max().item()
print(max_step_count)
print(batches["next", "step_count"])
if max_step_count > max_step:
print("Problem!")
print(max_step_count)
break
else:
print("No problem!")
Notice the way I create the serial env (I don't reuse the same EnvCreator). I also have no issue with ParallelEnv. |
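To make the difference explicit, here is a minimal sketch of the two constructions being contrasted (assuming the same build_cpu_single_env, n_env, and device as in the script above); only the creator sharing differs:
# Shared creator: every sub-env is built from the same EnvCreator instance
# (the construction used in the earlier scripts, which still showed the issue on MPS).
env_shared = SerialEnv(n_env, EnvCreator(build_cpu_single_env), device=device)
# One creator per sub-env: the construction used in the script above, which avoids the issue.
env_separate = SerialEnv(n_env, [EnvCreator(build_cpu_single_env) for _ in range(n_env)], device=device)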
Not for me. Running the above script gives me the same transient errors I described. Which commits are you using? I'm on the main branches of both TorchRL and TensorDict.
torchrl ❯ mamba env export
name: torchrl
channels:
- pytorch
- conda-forge
dependencies:
- brotli=1.1.0=hb547adb_1
- brotli-bin=1.1.0=hb547adb_1
- bzip2=1.0.8=h93a5062_5
- ca-certificates=2023.11.17=hf0a4a13_0
- certifi=2023.11.17=pyhd8ed1ab_0
- colorama=0.4.6=pyhd8ed1ab_0
- contourpy=1.2.0=py310hd137fd4_0
- cycler=0.12.1=pyhd8ed1ab_0
- exceptiongroup=1.2.0=pyhd8ed1ab_2
- filelock=3.13.1=pyhd8ed1ab_0
- fonttools=4.47.2=py310hd125d64_0
- freetype=2.12.1=hadb7bae_2
- gmp=6.3.0=h965bd2d_0
- gmpy2=2.1.2=py310h2e6cad2_1
- imageio=2.33.1=pyh8c1a49c_0
- iniconfig=2.0.0=pyhd8ed1ab_0
- jinja2=3.1.3=pyhd8ed1ab_0
- kiwisolver=1.4.5=py310h38f39d4_1
- lcms2=2.16=ha0e7c42_0
- lerc=4.0.0=h9a09cb3_0
- libblas=3.9.0=21_osxarm64_openblas
- libbrotlicommon=1.1.0=hb547adb_1
- libbrotlidec=1.1.0=hb547adb_1
- libbrotlienc=1.1.0=hb547adb_1
- libcblas=3.9.0=21_osxarm64_openblas
- libcxx=16.0.6=h4653b0c_0
- libdeflate=1.19=hb547adb_0
- libffi=3.4.2=h3422bc3_5
- libgfortran=5.0.0=13_2_0_hd922786_2
- libgfortran5=13.2.0=hf226fd6_2
- libjpeg-turbo=3.0.0=hb547adb_1
- liblapack=3.9.0=21_osxarm64_openblas
- libopenblas=0.3.26=openmp_h6c19121_0
- libpng=1.6.39=h76d750c_0
- libsqlite=3.44.2=h091b4b1_0
- libtiff=4.6.0=ha8a6c65_2
- libwebp-base=1.3.2=hb547adb_0
- libxcb=1.15=hf346824_0
- libzlib=1.2.13=h53f4e23_5
- llvm-openmp=17.0.6=hcd81f8e_0
- markupsafe=2.1.4=py310hd125d64_0
- matplotlib=3.8.2=py310hb6292c7_0
- matplotlib-base=3.8.2=py310h9d2df84_0
- mpc=1.3.1=h91ba8db_0
- mpfr=4.2.1=h9546428_0
- mpmath=1.3.0=pyhd8ed1ab_0
- munkres=1.1.4=pyh9f0ad1d_0
- ncurses=6.4=h463b476_2
- networkx=3.2.1=pyhd8ed1ab_0
- numpy=1.26.3=py310hd45542a_0
- openjpeg=2.5.0=h4c1507b_3
- openssl=3.2.1=h0d3ecfb_0
- packaging=23.2=pyhd8ed1ab_0
- pcre2=10.42=h26f9a81_0
- pillow=10.2.0=py310hfae7ebd_0
- pip=23.3.2=pyhd8ed1ab_0
- pluggy=1.4.0=pyhd8ed1ab_0
- pthread-stubs=0.4=h27ca646_1001
- pyparsing=3.1.1=pyhd8ed1ab_0
- pytest=8.0.0=pyhd8ed1ab_0
- python=3.10.13=h2469fbe_1_cpython
- python-dateutil=2.8.2=pyhd8ed1ab_0
- python_abi=3.10=4_cp310
- pytorch=2.2.0=py3.10_0
- pyyaml=6.0.1=py310h2aa6e3c_1
- readline=8.2=h92ec313_1
- setuptools=69.0.3=pyhd8ed1ab_0
- six=1.16.0=pyh6c4a22f_0
- sympy=1.12=pypyh9d50eac_103
- tk=8.6.13=h5083fa2_1
- tomli=2.0.1=pyhd8ed1ab_0
- tornado=6.3.3=py310h2aa6e3c_1
- typing_extensions=4.9.0=pyha770c72_0
- tzdata=2023d=h0c530f3_0
- unicodedata2=15.1.0=py310h2aa6e3c_0
- wheel=0.42.0=pyhd8ed1ab_0
- xorg-libxau=1.0.11=hb547adb_0
- xorg-libxdmcp=1.1.3=h27ca646_0
- xz=5.2.6=h57fd34a_0
- yaml=0.2.5=h3422bc3_2
- zstd=1.5.5=h4f39d0f_0
- pip:
- absl-py==2.1.0
- ale-py==0.8.1
- annotated-types==0.6.0
- antlr4-python3-runtime==4.9.3
- appdirs==1.4.4
- attrs==23.2.0
- autorom==0.4.2
- autorom-accept-rom-license==0.6.1
- black==24.1.1
- box2d-py==2.3.5
- cfgv==3.4.0
- charset-normalizer==3.3.2
- click==8.1.7
- cloudpickle==3.0.0
- decorator==4.4.2
- distlib==0.3.8
- docker-pycreds==0.4.0
- etils==1.6.0
- farama-notifications==0.0.4
- fsspec==2023.12.2
- gitdb==4.0.11
- gitpython==3.1.41
- glfw==2.6.5
- gymnasium==0.29.1
- hydra-core==1.3.2
- identify==2.5.33
- idna==3.6
- imageio-ffmpeg==0.4.9
- importlib-resources==6.1.1
- joblib==1.3.2
- jsonref==1.1.0
- jsonschema==4.21.1
- jsonschema-specifications==2023.12.1
- moviepy==1.0.3
- mujoco==3.1.1
- mypy-extensions==1.0.0
- nodeenv==1.8.0
- omegaconf==2.3.0
- pathspec==0.12.1
- platformdirs==4.2.0
- pre-commit==3.6.0
- proglog==0.1.10
- protobuf==4.25.2
- psutil==5.9.8
- pydantic==2.6.0
- pydantic-core==2.16.1
- pygame==2.5.2
- pyopengl==3.1.7
- referencing==0.33.0
- requests==2.31.0
- rpds-py==0.17.1
- scikit-learn==1.4.0
- scipy==1.12.0
- sentry-sdk==1.40.0
- setproctitle==1.3.3
- shimmy==0.2.1
- smmap==5.0.1
- sweeps==0.2.0
- swig==4.1.1.post1
- threadpoolctl==3.2.0
- tqdm==4.66.1
- urllib3==2.2.0
- virtualenv==20.25.0
- wandb==0.16.2
- zipp==3.17.0 |
Can you check #1900 whenever you have time? |
Almost solved. It works with SerialEnv and ParallelEnv, but somehow breaks when a TransformedEnv is added on top of the ParallelEnv.
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
ExplorationType,
StepCounter,
TransformedEnv,
ParallelEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor
max_step = 10
n_env = 4
env_id = "MountainCar-v0"
device = "mps"
def build_cpu_single_env():
env = GymEnv(env_id, device="cpu")
env = TransformedEnv(env)
env.append_transform(StepCounter(max_steps=max_step, truncated_key="foo"))
return env
def build_actor(env):
return ProbabilisticActor(
module=TensorDictModule(
nn.LazyLinear(env.action_spec.space.n),
in_keys=["observation"],
out_keys=["logits"],
),
spec=env.action_spec,
distribution_class=OneHotCategorical,
in_keys=["logits"],
default_interaction_type=ExplorationType.RANDOM,
)
if __name__ == "__main__":
# Works with both ParallelEnv and SerialEnv
env = ParallelEnv(n_env, lambda: build_cpu_single_env(), device=device)
# Breaks when adding a Transformed env on Parallel Env.
env = TransformedEnv(env)
policy_module = build_actor(env)
policy_module.to(device)
policy_module(env.reset())
max_step = min(max_step, 200)
for i in range(10):
batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
max_step_count = batches["next", "step_count"].max().item()
# print(max_step_count)
# print(batches["next", "step_count"])
if max_step_count > max_step:
print("Problem 1!")
print(max_step_count)
print(batches["next", "step_count"])
break
elif max_step_count < max_step:
print("Problem 2!")
print(max_step_count)
print(batches["next", "step_count"])
break
else:
print("No problem!")
Traceback (most recent call last):
File "/Users/moalla/projects/open-source/TorchRL/tests/issue_env_device_parallel.py", line 48, in <module>
batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2395, in rollout
tensordicts = self._rollout_nonstop(**kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2484, in _rollout_nonstop
tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2554, in step_and_maybe_reset
tensordict_ = self.reset(tensordict_)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2056, in reset
tensordict_reset = self._reset(tensordict, **kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/transforms/transforms.py", line 785, in _reset
tensordict_reset = self.base_env._reset(tensordict=tensordict, **kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 58, in decorated_fun
return fun(self, *args, **kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 1315, in _reset
self.shared_tensordicts[i].apply_(
File "/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/base.py", line 3597, in apply_
return self.apply(fn, *others, inplace=True, **kwargs)
File "/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/base.py", line 3692, in apply
return self._apply_nest(
File "/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/_td.py", line 712, in _apply_nest
item_trsf = fn(item, *_others)
File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 1312, in tentative_update
val.copy_(other)
RuntimeError: destOffset % 4 == 0 INTERNAL ASSERT FAILED at "/Users/runner/work/_temp/anaconda/conda-bld/pytorch_1704987091277/work/aten/src/ATen/native/mps/operations/Copy.mm":107, please report a bug to PyTorch. Unaligned blit request |
Actually, there is another issue with ParallelEnv: the native truncation key is not faithful.
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import ExplorationType, ParallelEnv, StepCounter, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor
max_step = 210
n_env = 4
env_id = "MountainCar-v0"
NATIVE_TRUNCATION = 200
device = "mps"
max_step = min(max_step, NATIVE_TRUNCATION)
def build_cpu_single_env():
env = GymEnv(env_id, device="cpu")
env = TransformedEnv(env)
env.append_transform(StepCounter(max_steps=max_step))
return env
def build_actor(env):
return ProbabilisticActor(
module=TensorDictModule(
nn.LazyLinear(env.action_spec.space.n),
in_keys=["observation"],
out_keys=["logits"],
),
spec=env.action_spec,
distribution_class=OneHotCategorical,
in_keys=["logits"],
default_interaction_type=ExplorationType.RANDOM,
)
if __name__ == "__main__":
env = ParallelEnv(n_env, lambda: build_cpu_single_env(), device=device)
# env = TransformedEnv(env)
policy_module = build_actor(env)
policy_module.to(device)
policy_module(env.reset())
for i in range(10):
batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
max_step_count = batches["next", "step_count"].max().item()
if max_step_count > max_step:
print(max_step_count)
print(batches["next", "step_count"][:, -5:])
print("Problem! Got higher than max step count.")
break
elif max_step_count < max_step:
print(max_step_count)
print(batches["next", "step_count"][:, -5:])
print("Problem: Got less than max step count!")
break
else:
print(batches["next", "step_count"][:, -5:])
print("No problem!")
196
tensor([[[194],
[195],
[ 1],
[ 2],
[ 3]],
[[194],
[195],
[ 1],
[ 2],
[ 3]],
[[195],
[196],
[ 1],
[ 2],
[ 3]],
[[195],
[196],
[ 1],
[ 2],
[ 3]]], device='mps:0')
Problem: Got less than max step count! |
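For context on the NATIVE_TRUNCATION = 200 constant in that script: MountainCar-v0 registers a native 200-step time limit in Gymnasium, which is the value max_step gets clamped to, so the StepCounter and the env's own TimeLimit are meant to truncate at the same point. A quick way to check the registered limit (illustrative, not part of the original report):
import gymnasium as gym
# The registered spec carries the native truncation horizon of the environment.
env = gym.make("MountainCar-v0")
print(env.spec.max_episode_steps)  # 200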
Ok, then I will need access to an MPS device (I won't have one for the upcoming ~3 weeks, I think) :/ |
This is now solved, right? |
Describe the bug
When the batched env device is cuda, the step count on the batched env seems completely off from what it should be. When the batched env device is mps, there is a segmentation fault. I wonder whether it is only the step count that is corrupted or other data too, including the observation ...
To Reproduce
On CUDA
Problem! 1065353217
On MPS
System info