[BUG] Problems with BatchedEnv on accelerated device with single envs on cpu #1864

Open
skandermoalla opened this issue Feb 1, 2024 · 29 comments · Fixed by #1866

@skandermoalla (Contributor) commented Feb 1, 2024

Describe the bug

When the batched env device is cuda, the step count reported by the batched env is completely off from what it should be.
When the batched env device is mps, there is a segmentation fault.

I wonder whether only the step count is corrupted or other data as well, including the observation (see the hypothetical check sketched after the outputs below).

To Reproduce

from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    StepCounter,
    TransformedEnv, SerialEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

max_step = 10
n_env = 4
env_id = "CartPole-v1"
device = "mps"


def build_cpu_single_env():
    env = GymEnv(env_id, device="cpu")
    env = TransformedEnv(env)
    env.append_transform(StepCounter(max_steps=max_step, step_count_key="single_env_step_count", truncated_key="single_env_truncated"))
    return env

def build_actor(env):
    return ProbabilisticActor(
        module=TensorDictModule(
            nn.LazyLinear(env.action_spec.space.n),
            in_keys=["observation"],
            out_keys=["logits"],
        ),
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=["logits"],
        default_interaction_type=ExplorationType.RANDOM,
    )

if __name__ == "__main__":
    env = SerialEnv(n_env, EnvCreator(lambda: build_cpu_single_env()), device=device)
    policy_module = build_actor(env)
    policy_module.to(device)
    policy_module(env.reset())

    for i in range(10):
        batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
        max_step_count = batches["next", "single_env_step_count"].max().item()
        if max_step_count > max_step:
            print("Problem!")
            print(max_step_count)
            break
    else:
        print("No problem!")

On CUDA

Problem!
1065353217

On MPS

python(57380,0x1dd5e5000) malloc: Incorrect checksum for freed object 0x11767f308: probably modified after being freed.
Corrupt value: 0xbd414ea83cfeb221
python(57380,0x1dd5e5000) malloc: *** set a breakpoint in malloc_error_break to debug
[1]    57380 abort      python tests/issue_env_device.py
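
To probe whether entries other than the step count are also corrupted, one could add a quick sanity check at the end of the reproduction script above. This is a hypothetical sketch, not part of the original report; it only assumes the batches tensordict returned by env.rollout and needs an extra import torch at the top of the script:

import torch

# Hypothetical extra check: look for garbage values in the observations of the same rollout.
obs = batches["next", "observation"]
print("all finite:", torch.isfinite(obs).all().item())
print("max magnitude:", obs.abs().max().item())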

System info

import torchrl, tensordict, torch, numpy, sys
print(torch.__version__, tensordict.__version__, torchrl.__version__, numpy.__version__, sys.version, sys.platform)

2.2.0a0+81ea7a4 0.4.0+eaef29e 0.4.0+01a2216 1.24.4 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] linux

2.2.0 0.4.0+eaef29e 0.4.0+01a2216 1.26.3 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:35:25) [Clang 16.0.6 ] darwin
skandermoalla added the bug label on Feb 1, 2024
@skandermoalla (Contributor, Author) commented Feb 1, 2024

To reproduce the bug on ParallelEnv you need some wizardry:

from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    StepCounter,
    TransformedEnv,
    ParallelEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

max_step = 10
n_env = 4
env_id = "MountainCar-v0"
device = "cuda:0"


def build_cpu_single_env():
    env = GymEnv(env_id, device="cpu")
    env = TransformedEnv(env)
    env.append_transform(StepCounter(max_steps=max_step, truncated_key="foo"))
    return env


def build_actor(env):
    return ProbabilisticActor(
        module=TensorDictModule(
            nn.LazyLinear(env.action_spec.space.n),
            in_keys=["observation"],
            out_keys=["logits"],
        ),
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=["logits"],
        default_interaction_type=ExplorationType.RANDOM,
    )


if __name__ == "__main__":
    env = ParallelEnv(n_env, EnvCreator(lambda: build_cpu_single_env()), device=device)
    env = TransformedEnv(env)
    policy_module = build_actor(env)
    policy_module.to(device)
    policy_module(env.reset())
    for i in range(10):
        batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
        max_step_count = batches["next", "step_count"].max().item()
        if max_step_count > max_step:
            print("Problem!")
            print(max_step_count)
            break
    else:
        print("No problem!")

@vmoens (Contributor) commented Feb 1, 2024

I can reproduce the initial example only if the sub-env is on "cpu", so it's likely just a problem of casting from device to device in the serial env.
I will check that tomorrow!
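
As an aside, the huge values reported (e.g. 1065353217) look like integer reinterpretations of float bit patterns, which would be consistent with a device-casting/copy bug. A minimal, hypothetical sketch of that effect (this is not the actual TorchRL code path):

import torch

# Reinterpreting the raw bytes of a float32 step count as int32 produces
# "arbitrary"-looking huge values of this magnitude.
step_count = torch.ones(4, 1)            # float32 value 1.0
print(step_count.view(torch.int32))      # tensor([[1065353216], ...]) -- the bit pattern of float 1.0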

@skandermoalla (Contributor, Author) commented:

Nice, thanks! Indeed, it's probably device casting gone wrong somewhere, since MPS literally crashed with a segfault.
Could you reproduce the one with ParallelEnv? That's as impactful as the SerialEnv one.

vmoens linked a pull request on Feb 2, 2024 that will close this issue
@vmoens (Contributor) commented Feb 2, 2024

Can you have a go at #1866 for cpu envs? For me it works with sub-envs on cpu and cuda (even with 100 outer steps).

@vmoens (Contributor) commented Feb 2, 2024

Also, I ran the second example with 10K outer iterations but could not reproduce the issue (on the branch of the PR, but I did not change much from main), so I'm not sure how to address this.

VERBOSE=1 python -c """import tqdm
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    StepCounter,
    TransformedEnv,
    ParallelEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

max_step = 10
n_env = 4
env_id = 'MountainCar-v0'
device = 'cuda:0'


def build_cpu_single_env():
    env = GymEnv(env_id, device='cpu')
    env = TransformedEnv(env)
    env.append_transform(StepCounter(max_steps=max_step, truncated_key='foo'))
    return env


def build_actor(env):
    return ProbabilisticActor(
        module=TensorDictModule(
            nn.LazyLinear(env.action_spec.space.n),
            in_keys=['observation'],
            out_keys=['logits'],
        ),
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=['logits'],
        default_interaction_type=ExplorationType.RANDOM,
    )


if __name__ == '__main__':
    env = ParallelEnv(n_env, EnvCreator(lambda: build_cpu_single_env()), device=device)
    env = TransformedEnv(env)
    policy_module = build_actor(env)
    policy_module.to(device)
    policy_module(env.reset())
    for i in tqdm.tqdm(range(10000)):
        batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
        max_step_count = batches['next', 'step_count'].max().item()
        if max_step_count > max_step:
            print('Problem!')
            print(max_step_count)
            print(batches['next', 'step_count'])
            break
    else:
        print('No problem!')
"""

@skandermoalla (Contributor, Author) commented Feb 2, 2024

VERBOSE=1 python -c """import tqdm                                                                                                                                                                                                                 
from tensordict.nn import TensorDictModule
from torch import nn                                             
from torchrl.envs import (
    EnvCreator,                                                               
    ExplorationType,                                                                   
    StepCounter,         
    TransformedEnv,
    ParallelEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

max_step = 10                                                                 
n_env = 4                                                                              
env_id = 'MountainCar-v0'
device = 'cuda:0'

...
/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
  logger.warn(
/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
  logger.warn(
2024-02-02 13:10:48,627 [torchrl][INFO] resetting implement_for
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
  warnings.warn('Lazy modules are a new feature under heavy development '
2024-02-02 13:10:48,872 [torchrl][INFO] initiating worker 0
2024-02-02 13:10:48,965 [torchrl][INFO] initiating worker 1
2024-02-02 13:10:48,967 [torchrl][INFO] initiating worker 2
2024-02-02 13:10:48,969 [torchrl][INFO] initiating worker 3
2024-02-02 13:10:51,116 [torchrl][INFO] resetting implement_for
2024-02-02 13:10:51,142 [torchrl][INFO] resetting implement_for
2024-02-02 13:10:51,192 [torchrl][INFO] resetting implement_for
2024-02-02 13:10:51,196 [torchrl][INFO] resetting implement_for
  0%|                                                                                                                                                                                                                      | 0/10000 [00:00<?, ?it/s]Problem!
3201372667
tensor([[[         1],
         [         2],
         [         3],
         [         4],
         [         5],
         [         6],
         [         7],
         [         8],
         [         9],
         [        10],
         [         1],
         [         1],
         [         1]],

        [[         1],
         [         2],
         [         3],
         [         4],
         [         5],
         [         6],
         [         7],
         [         8],
         [         9],
         [        10],
         [         1],
         [         1],
         [3201372667]],

        [[         1],
         [         2],
         [         3],
         [         4],
         [         5],
         [         6],
         [         7],
         [         8],
         [         9],
         [        10],
         [         1],
         [         1],
         [         1]],

        [[         1],
         [         2],
         [         3],
         [         4],
         [         5],
         [         6],
         [         7],
         [         8],
         [         9],
         [        10],
         [         1],
         [         1],
         [         1]]], device='cuda:0')
  0%|            
❯ python                                   
Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torchrl, tensordict, torch, numpy, sys

>>> print(torch.__version__, tensordict.__version__, torchrl.__version__, numpy.__version__, sys.version, sys.platform)
2.2.0a0+81ea7a4 0.4.0+eaef29e 0.4.0+01a2216 1.24.4 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] linux

I will try with the #1866 branch and on another cluster.

@skandermoalla (Contributor, Author) commented Feb 2, 2024

The SerialEnv example was solved with #1866. I also tried poking at it a bit and it was fine.

I will try the ParallelEnv one.

@skandermoalla (Contributor, Author) commented Feb 2, 2024

The problem is now different with ParallelEnv; that's probably why it didn't error for you @vmoens.

VERBOSE=1 python -c """import tqdm
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    StepCounter,
    TransformedEnv,
    ParallelEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

max_step = 10
n_env = 4
env_id = 'MountainCar-v0'
device = 'cuda:0'


def build_cpu_single_env():
    env = GymEnv(env_id, device='cpu')
    env = TransformedEnv(env)
    env.append_transform(StepCounter(max_steps=max_step, truncated_key='foo'))
    return env


def build_actor(env):
    return ProbabilisticActor(
        module=TensorDictModule(
            nn.LazyLinear(env.action_spec.space.n),
            in_keys=['observation'],
            out_keys=['logits'],
        ),
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=['logits'],
        default_interaction_type=ExplorationType.RANDOM,
    )


if __name__ == '__main__':
    env = ParallelEnv(n_env, EnvCreator(lambda: build_cpu_single_env()), device=device)
    env = TransformedEnv(env)
    policy_module = build_actor(env)
    policy_module.to(device)
    policy_module(env.reset())
    for i in tqdm.tqdm(range(10)):
        batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
        max_step_count = batches['next', 'step_count'].max().item()
        if max_step_count > max_step:
            print('Problem 1!')
            print(max_step_count)
            print(batches['next', 'step_count'])
            break
        elif max_step_count < max_step:
            print('Problem 2!')
            print(max_step_count)
            print(batches['next', 'step_count'])
            break
    else:
        print('No problem!')
"""
/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
  logger.warn(
/usr/local/lib/python3.10/dist-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
  logger.warn(
2024-02-02 13:44:03,727 [torchrl][INFO] resetting implement_for
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/lazy.py:180: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
  warnings.warn('Lazy modules are a new feature under heavy development '
2024-02-02 13:44:03,981 [torchrl][INFO] initiating worker 0
2024-02-02 13:44:04,041 [torchrl][INFO] initiating worker 1
2024-02-02 13:44:04,043 [torchrl][INFO] initiating worker 2
2024-02-02 13:44:04,045 [torchrl][INFO] initiating worker 3
2024-02-02 13:44:06,233 [torchrl][INFO] resetting implement_for
2024-02-02 13:44:06,241 [torchrl][INFO] resetting implement_for
2024-02-02 13:44:06,255 [torchrl][INFO] resetting implement_for
2024-02-02 13:44:06,273 [torchrl][INFO] resetting implement_for
 10%|████████████████████▉                                                                                                                                                                                            | 1/10 [00:00<00:07,  1.23it/s]Problem 2!
1
tensor([[[1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1]],

        [[1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1]],

        [[1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1]],

        [[1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1]]], device='cuda:0')
 10%|████████████████████▉      
❯ python
Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torchrl, tensordict, torch, numpy, sys
>>> print(torch.__version__, tensordict.__version__, torchrl.__version__, numpy.__version__, sys.version, sys.platform)
2.2.0a0+81ea7a4 0.4.0+eaef29e 0.4.0+1ea3c74 1.24.4 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] linux

I will try on a different cluster.

@skandermoalla (Contributor, Author) commented Feb 2, 2024

Happens on two different clusters with different CPUs and GPUs (same Docker image though, the NVIDIA NGC PyTorch).

@skandermoalla (Contributor, Author) commented:

On MPS it's not a segfault anymore but the original arbitrary-number bug:

from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    StepCounter,
    TransformedEnv,
    SerialEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

max_step = 10
n_env = 4
env_id = "CartPole-v1"
device = "mps"


def build_cpu_single_env():
    env = GymEnv(env_id, device="cpu")
    env = TransformedEnv(env)
    env.append_transform(StepCounter(max_steps=max_step))
    return env


def build_actor(env):
    return ProbabilisticActor(
        module=TensorDictModule(
            nn.LazyLinear(env.action_spec.space.n),
            in_keys=["observation"],
            out_keys=["logits"],
        ),
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=["logits"],
        default_interaction_type=ExplorationType.RANDOM,
    )


if __name__ == "__main__":
    env = SerialEnv(n_env, EnvCreator(lambda: build_cpu_single_env()), device=device)
    policy_module = build_actor(env)
    policy_module.to(device)
    policy_module(env.reset())

    for i in range(10):
        batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
        max_step_count = batches["next", "step_count"].max().item()
        print(max_step_count)
        print(batches["next", "step_count"])
        if max_step_count > max_step:
            print("Problem!")
            print(max_step_count)
            break
    else:
        print("No problem!")

gives

/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
  logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
  logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/torch/nn/modules/lazy.py:181: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
  warnings.warn('Lazy modules are a new feature under heavy development '
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
  logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
  logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
  logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
  logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
  logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
  logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.num_envs to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.num_envs` for environment variables or `env.get_wrapper_attr('num_envs')` that will search the reminding wrappers.
  logger.warn(
/Users/moalla/mambaforge/envs/torchrl/lib/python3.10/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.reward_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.reward_space` for environment variables or `env.get_wrapper_attr('reward_space')` that will search the reminding wrappers.
  logger.warn(
5157210208
tensor([[[         6],
         [         6],
         [         1],
         [         6],
         [         0],
         [         6],
         [         6],
         [         6],
         [         6],
         [         6],
         [         6],
         [         1],
         [         6]],

        [[         1],
         [         1],
         [         1],
         [         1],
         [5157210208],
         [         1],
         [         2],
         [         1],
         [         1],
         [         2],
         [         1],
         [         2],
         [         1]],

        [[         1],
         [         1],
         [         1],
         [         1],
         [         6],
         [         1],
         [         2],
         [         1],
         [         1],
         [         2],
         [         1],
         [         2],
         [         1]],

        [[         1],
         [         1],
         [         1],
         [         1],
         [         1],
         [         1],
         [         2],
         [         1],
         [         1],
         [         2],
         [         1],
         [         2],
         [         1]]], device='mps:0')
Problem!
5157210208

@skandermoalla (Contributor, Author) commented:

MPS still gives a segfault for ParallelEnv.

@vmoens (Contributor) commented Feb 4, 2024

I think it's solved now (for cuda, on both serial and parallel, on the bugfix PR).
I will have a look at mps later!

@skandermoalla (Contributor, Author) commented:

Does this need a specific branch on tensordict?

@vmoens (Contributor) commented Feb 5, 2024

Yeah, sorry, I'm patching TensorDict; let me quickly revert the latest changes, which should be part of another PR.

@vmoens (Contributor) commented Feb 5, 2024

I changed it, and tests seem to be passing. If they all do, I'll do a final run of your examples and check the status on MPS. If it all runs smoothly I will consider the PR as good unless you wish to do a proper review of it.

@skandermoalla (Contributor, Author) commented:

I'll poke at it a bit now and give my feedback soon. So I should test with this branch of TorchRL and main of TensorDict?

@vmoens (Contributor) commented Feb 5, 2024

Yes, tensordict main is up to date.

@skandermoalla (Contributor, Author) commented Feb 5, 2024

All good for CUDA! Awesome! (I tested some scripts but didn't check the PR code.)

❯ python
Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torchrl, tensordict, torch, numpy, sys
>>> print(torch.__version__, tensordict.__version__, torchrl.__version__, numpy.__version__, sys.version, sys.platform)
2.2.0a0+81ea7a4 0.4.0+99705db 0.4.0+1f485e9 1.24.4 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] linux

(I played with variations of things like the ones here: https://github.com/skandermoalla/TorchRL/tree/34c8abf19fd5a5177a2d5eadd5a5b1f57d51ab6c/tests)

@skandermoalla (Contributor, Author) commented:

Testing for MPS.

@skandermoalla (Contributor, Author) commented:

Not yet for MPS.

For SerialEnv (https://github.com/skandermoalla/TorchRL/blob/34c8abf19fd5a5177a2d5eadd5a5b1f57d51ab6c/tests/issue_env_device_serial.py) I get different errors that appear nondeterministically:

python(22437,0x1d9879300) malloc: tiny_free_list_remove_ptr: Internal invariant broken (next ptr of prev): ptr=0x139ced580, prev_next=0x0
python(22437,0x1d9879300) malloc: *** set a breakpoint in malloc_error_break to debug
[1]    22437 abort      python tests/issue_env_device_serial.py
Traceback (most recent call last):
  File "/Users/skander/projects/open-source/TorchRL/tests/issue_env_device_serial.py", line 47, in <module>
    batches = env.rollout((2 * max_step + 3), policy=policy_module, break_when_any_done=False)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2395, in rollout
    tensordicts = self._rollout_nonstop(**kwargs)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2484, in _rollout_nonstop
    tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2554, in step_and_maybe_reset
    tensordict_ = self.reset(tensordict_)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2056, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 58, in decorated_fun
    return fun(self, *args, **kwargs)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 781, in _reset
    _td = _env.reset(tensordict=tensordict_, **kwargs)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2071, in reset
    return self._reset_proc_data(tensordict, tensordict_reset)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/transforms/transforms.py", line 795, in _reset_proc_data
    self._reset_check_done(tensordict, tensordict_reset)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2103, in _reset_check_done
    raise RuntimeError(
RuntimeError: Env done entry 'truncated' was (partially) True after reset on specified '_reset' dimensions. This is not allowed.

and

python(22700,0x16dcdb000) malloc: Incorrect checksum for freed object 0x13b396f08: probably modified after being freed.
Corrupt value: 0xbc99eb7cbad79154
python(22700,0x16dcdb000) malloc: *** set a breakpoint in malloc_error_break to debug
[1]    22700 abort      python tests/issue_env_device_serial.py

For ParallelEnv (https://github.com/skandermoalla/TorchRL/blob/34c8abf19fd5a5177a2d5eadd5a5b1f57d51ab6c/tests/issue_env_device_parallel.py) I also get nondeterministic errors:

Traceback (most recent call last):
  File "/Users/skander/projects/open-source/TorchRL/tests/issue_env_device_parallel.py", line 45, in <module>
    policy_module(env.reset())
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2071, in reset
    return self._reset_proc_data(tensordict, tensordict_reset)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/transforms/transforms.py", line 795, in _reset_proc_data
    self._reset_check_done(tensordict, tensordict_reset)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2123, in _reset_check_done
    raise RuntimeError(
RuntimeError: The done entry 'truncated' was (partially) True after a call to reset() in env TransformedEnv(
    env=ParallelEnv(
        env=TransformedEnv(
        env=GymEnv(env=MountainCar-v0, batch_size=torch.Size([]), device=cpu),
        transform=Compose(
                StepCounter(keys=[]))), 
        batch_size=torch.Size([4])),
    transform=Compose(
    )).

and

Traceback (most recent call last):
  File "/Users/skander/projects/open-source/TorchRL/tests/issue_env_device_parallel.py", line 47, in <module>
    batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2395, in rollout
    tensordicts = self._rollout_nonstop(**kwargs)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2484, in _rollout_nonstop
    tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2554, in step_and_maybe_reset
    tensordict_ = self.reset(tensordict_)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2056, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/transforms/transforms.py", line 785, in _reset
    tensordict_reset = self.base_env._reset(tensordict=tensordict, **kwargs)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 58, in decorated_fun
    return fun(self, *args, **kwargs)
  File "/Users/skander/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 1298, in _reset
    channel.send(out)
  File "/Users/skander/mambaforge/envs/torchrl/lib/python3.10/multiprocessing/connection.py", line 206, in send
    self._send_bytes(_ForkingPickler.dumps(obj))
  File "/Users/skander/mambaforge/envs/torchrl/lib/python3.10/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
  File "/Users/skander/mambaforge/envs/torchrl/lib/python3.10/site-packages/torch/multiprocessing/reductions.py", line 557, in reduce_storage
    metadata = storage._share_filename_cpu_()
  File "/Users/skander/mambaforge/envs/torchrl/lib/python3.10/site-packages/torch/storage.py", line 294, in wrapper
    return fn(self, *args, **kwargs)
  File "/Users/skander/mambaforge/envs/torchrl/lib/python3.10/site-packages/torch/storage.py", line 368, in _share_filename_cpu_
    return super()._share_filename_cpu_(*args, **kwargs)
RuntimeError: _share_filename_: only available on CPU

@vmoens (Contributor) commented Feb 5, 2024

Those seem to be different issues than the CUDA ones, so I think we should go ahead with the PR and make sure these things are working ok with MPS separately!

@vmoens (Contributor) commented Feb 5, 2024

Reopening to keep track of progress with MPS

@vmoens (Contributor) commented Feb 7, 2024

If I avoid updating slices (see this bug), which happens if you create a single copy of the env, I have no issue with the following code:

from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
    EnvCreator,
    ExplorationType,
    StepCounter,
    TransformedEnv,
    SerialEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

max_step = 10
n_env = 4
env_id = "CartPole-v1"
device = "mps"


def build_cpu_single_env():
    env = GymEnv(env_id, device="cpu")
    env = TransformedEnv(env)
    env.append_transform(StepCounter(max_steps=max_step))
    return env


def build_actor(env):
    return ProbabilisticActor(
        module=TensorDictModule(
            nn.LazyLinear(env.action_spec.space.n),
            in_keys=["observation"],
            out_keys=["logits"],
        ),
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=["logits"],
        default_interaction_type=ExplorationType.RANDOM,
    )


if __name__ == "__main__":
    env = SerialEnv(n_env, [EnvCreator(build_cpu_single_env) for _ in range(n_env)], device=device)
    policy_module = build_actor(env)
    policy_module.to(device)
    policy_module(env.reset())

    for i in range(10):
        batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
        max_step_count = batches["next", "step_count"].max().item()
        print(max_step_count)
        print(batches["next", "step_count"])
        if max_step_count > max_step:
            print("Problem!")
            print(max_step_count)
            break
    else:
        print("No problem!")

Notice the way I create the serial env (I don't use the same EnvCreator for every worker); the difference is sketched below.
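
For clarity, here is the contrast between the original reproduction and this script (a minimal sketch reusing the definitions from the script above; only the way the sub-env creators are passed differs):

# Original reproduction: one EnvCreator instance shared by all workers.
env_shared = SerialEnv(n_env, EnvCreator(build_cpu_single_env), device=device)

# This script: one EnvCreator per worker, which avoids the slice-update path mentioned above.
env_per_worker = SerialEnv(
    n_env, [EnvCreator(build_cpu_single_env) for _ in range(n_env)], device=device
)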

I also have no issue with ParallelEnv.

@skandermoalla (Contributor, Author) commented:

Not for me. Running the above script gives me the same transient errors I described. Which commits are you using? I'm on the main branches of both TorchRL and TensorDict.
Here's my environment:

torchrl ❯ mamba env export                                                        
name: torchrl
channels:
  - pytorch
  - conda-forge
dependencies:
  - brotli=1.1.0=hb547adb_1
  - brotli-bin=1.1.0=hb547adb_1
  - bzip2=1.0.8=h93a5062_5
  - ca-certificates=2023.11.17=hf0a4a13_0
  - certifi=2023.11.17=pyhd8ed1ab_0
  - colorama=0.4.6=pyhd8ed1ab_0
  - contourpy=1.2.0=py310hd137fd4_0
  - cycler=0.12.1=pyhd8ed1ab_0
  - exceptiongroup=1.2.0=pyhd8ed1ab_2
  - filelock=3.13.1=pyhd8ed1ab_0
  - fonttools=4.47.2=py310hd125d64_0
  - freetype=2.12.1=hadb7bae_2
  - gmp=6.3.0=h965bd2d_0
  - gmpy2=2.1.2=py310h2e6cad2_1
  - imageio=2.33.1=pyh8c1a49c_0
  - iniconfig=2.0.0=pyhd8ed1ab_0
  - jinja2=3.1.3=pyhd8ed1ab_0
  - kiwisolver=1.4.5=py310h38f39d4_1
  - lcms2=2.16=ha0e7c42_0
  - lerc=4.0.0=h9a09cb3_0
  - libblas=3.9.0=21_osxarm64_openblas
  - libbrotlicommon=1.1.0=hb547adb_1
  - libbrotlidec=1.1.0=hb547adb_1
  - libbrotlienc=1.1.0=hb547adb_1
  - libcblas=3.9.0=21_osxarm64_openblas
  - libcxx=16.0.6=h4653b0c_0
  - libdeflate=1.19=hb547adb_0
  - libffi=3.4.2=h3422bc3_5
  - libgfortran=5.0.0=13_2_0_hd922786_2
  - libgfortran5=13.2.0=hf226fd6_2
  - libjpeg-turbo=3.0.0=hb547adb_1
  - liblapack=3.9.0=21_osxarm64_openblas
  - libopenblas=0.3.26=openmp_h6c19121_0
  - libpng=1.6.39=h76d750c_0
  - libsqlite=3.44.2=h091b4b1_0
  - libtiff=4.6.0=ha8a6c65_2
  - libwebp-base=1.3.2=hb547adb_0
  - libxcb=1.15=hf346824_0
  - libzlib=1.2.13=h53f4e23_5
  - llvm-openmp=17.0.6=hcd81f8e_0
  - markupsafe=2.1.4=py310hd125d64_0
  - matplotlib=3.8.2=py310hb6292c7_0
  - matplotlib-base=3.8.2=py310h9d2df84_0
  - mpc=1.3.1=h91ba8db_0
  - mpfr=4.2.1=h9546428_0
  - mpmath=1.3.0=pyhd8ed1ab_0
  - munkres=1.1.4=pyh9f0ad1d_0
  - ncurses=6.4=h463b476_2
  - networkx=3.2.1=pyhd8ed1ab_0
  - numpy=1.26.3=py310hd45542a_0
  - openjpeg=2.5.0=h4c1507b_3
  - openssl=3.2.1=h0d3ecfb_0
  - packaging=23.2=pyhd8ed1ab_0
  - pcre2=10.42=h26f9a81_0
  - pillow=10.2.0=py310hfae7ebd_0
  - pip=23.3.2=pyhd8ed1ab_0
  - pluggy=1.4.0=pyhd8ed1ab_0
  - pthread-stubs=0.4=h27ca646_1001
  - pyparsing=3.1.1=pyhd8ed1ab_0
  - pytest=8.0.0=pyhd8ed1ab_0
  - python=3.10.13=h2469fbe_1_cpython
  - python-dateutil=2.8.2=pyhd8ed1ab_0
  - python_abi=3.10=4_cp310
  - pytorch=2.2.0=py3.10_0
  - pyyaml=6.0.1=py310h2aa6e3c_1
  - readline=8.2=h92ec313_1
  - setuptools=69.0.3=pyhd8ed1ab_0
  - six=1.16.0=pyh6c4a22f_0
  - sympy=1.12=pypyh9d50eac_103
  - tk=8.6.13=h5083fa2_1
  - tomli=2.0.1=pyhd8ed1ab_0
  - tornado=6.3.3=py310h2aa6e3c_1
  - typing_extensions=4.9.0=pyha770c72_0
  - tzdata=2023d=h0c530f3_0
  - unicodedata2=15.1.0=py310h2aa6e3c_0
  - wheel=0.42.0=pyhd8ed1ab_0
  - xorg-libxau=1.0.11=hb547adb_0
  - xorg-libxdmcp=1.1.3=h27ca646_0
  - xz=5.2.6=h57fd34a_0
  - yaml=0.2.5=h3422bc3_2
  - zstd=1.5.5=h4f39d0f_0
  - pip:
      - absl-py==2.1.0
      - ale-py==0.8.1
      - annotated-types==0.6.0
      - antlr4-python3-runtime==4.9.3
      - appdirs==1.4.4
      - attrs==23.2.0
      - autorom==0.4.2
      - autorom-accept-rom-license==0.6.1
      - black==24.1.1
      - box2d-py==2.3.5
      - cfgv==3.4.0
      - charset-normalizer==3.3.2
      - click==8.1.7
      - cloudpickle==3.0.0
      - decorator==4.4.2
      - distlib==0.3.8
      - docker-pycreds==0.4.0
      - etils==1.6.0
      - farama-notifications==0.0.4
      - fsspec==2023.12.2
      - gitdb==4.0.11
      - gitpython==3.1.41
      - glfw==2.6.5
      - gymnasium==0.29.1
      - hydra-core==1.3.2
      - identify==2.5.33
      - idna==3.6
      - imageio-ffmpeg==0.4.9
      - importlib-resources==6.1.1
      - joblib==1.3.2
      - jsonref==1.1.0
      - jsonschema==4.21.1
      - jsonschema-specifications==2023.12.1
      - moviepy==1.0.3
      - mujoco==3.1.1
      - mypy-extensions==1.0.0
      - nodeenv==1.8.0
      - omegaconf==2.3.0
      - pathspec==0.12.1
      - platformdirs==4.2.0
      - pre-commit==3.6.0
      - proglog==0.1.10
      - protobuf==4.25.2
      - psutil==5.9.8
      - pydantic==2.6.0
      - pydantic-core==2.16.1
      - pygame==2.5.2
      - pyopengl==3.1.7
      - referencing==0.33.0
      - requests==2.31.0
      - rpds-py==0.17.1
      - scikit-learn==1.4.0
      - scipy==1.12.0
      - sentry-sdk==1.40.0
      - setproctitle==1.3.3
      - shimmy==0.2.1
      - smmap==5.0.1
      - sweeps==0.2.0
      - swig==4.1.1.post1
      - threadpoolctl==3.2.0
      - tqdm==4.66.1
      - urllib3==2.2.0
      - virtualenv==20.25.0
      - wandb==0.16.2
      - zipp==3.17.0

@vmoens (Contributor) commented Feb 12, 2024

Can you check #1900 whenever you have time?

@skandermoalla (Contributor, Author) commented:

Almost solved. It works with SerialEnv and ParallelEnv, but somehow breaks when a TransformedEnv is added on top of the ParallelEnv.

from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import (
    ExplorationType,
    StepCounter,
    TransformedEnv,
    ParallelEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

max_step = 10
n_env = 4
env_id = "MountainCar-v0"
device = "mps"


def build_cpu_single_env():
    env = GymEnv(env_id, device="cpu")
    env = TransformedEnv(env)
    env.append_transform(StepCounter(max_steps=max_step, truncated_key="foo"))
    return env


def build_actor(env):
    return ProbabilisticActor(
        module=TensorDictModule(
            nn.LazyLinear(env.action_spec.space.n),
            in_keys=["observation"],
            out_keys=["logits"],
        ),
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=["logits"],
        default_interaction_type=ExplorationType.RANDOM,
    )


if __name__ == "__main__":
    # Works with both ParallelEnv and SerialEnv
    env = ParallelEnv(n_env, lambda: build_cpu_single_env(), device=device)
    # Breaks when a TransformedEnv is added on top of the ParallelEnv.
    env = TransformedEnv(env)
    policy_module = build_actor(env)
    policy_module.to(device)
    policy_module(env.reset())

    max_step = min(max_step, 200)
    for i in range(10):
        batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
        max_step_count = batches["next", "step_count"].max().item()
        # print(max_step_count)
        # print(batches["next", "step_count"])
        if max_step_count > max_step:
            print("Problem 1!")
            print(max_step_count)
            print(batches["next", "step_count"])
            break
        elif max_step_count < max_step:
            print("Problem 2!")
            print(max_step_count)
            print(batches["next", "step_count"])
            break
    else:
        print("No problem!")

Running this gives:

Traceback (most recent call last):
  File "/Users/moalla/projects/open-source/TorchRL/tests/issue_env_device_parallel.py", line 48, in <module>
    batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2395, in rollout
    tensordicts = self._rollout_nonstop(**kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2484, in _rollout_nonstop
    tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2554, in step_and_maybe_reset
    tensordict_ = self.reset(tensordict_)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/common.py", line 2056, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/transforms/transforms.py", line 785, in _reset
    tensordict_reset = self.base_env._reset(tensordict=tensordict, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 58, in decorated_fun
    return fun(self, *args, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 1315, in _reset
    self.shared_tensordicts[i].apply_(
  File "/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/base.py", line 3597, in apply_
    return self.apply(fn, *others, inplace=True, **kwargs)
  File "/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/base.py", line 3692, in apply
    return self._apply_nest(
  File "/Users/moalla/projects/open-source/TorchRL/repos/tensordict/tensordict/_td.py", line 712, in _apply_nest
    item_trsf = fn(item, *_others)
  File "/Users/moalla/projects/open-source/TorchRL/repos/rl/torchrl/envs/batched_envs.py", line 1312, in tentative_update
    val.copy_(other)
RuntimeError: destOffset % 4 == 0 INTERNAL ASSERT FAILED at "/Users/runner/work/_temp/anaconda/conda-bld/pytorch_1704987091277/work/aten/src/ATen/native/mps/operations/Copy.mm":107, please report a bug to PyTorch. Unaligned blit request

@skandermoalla (Contributor, Author) commented:

Actually, there is another issue with ParallelEnv: the env's native truncation is not reflected faithfully.

from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.envs import ExplorationType, ParallelEnv, StepCounter, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import OneHotCategorical, ProbabilisticActor

max_step = 210
n_env = 4
env_id = "MountainCar-v0"
NATIVE_TRUNCATION = 200
device = "mps"
max_step = min(max_step, NATIVE_TRUNCATION)


def build_cpu_single_env():
    env = GymEnv(env_id, device="cpu")
    env = TransformedEnv(env)
    env.append_transform(StepCounter(max_steps=max_step))
    return env


def build_actor(env):
    return ProbabilisticActor(
        module=TensorDictModule(
            nn.LazyLinear(env.action_spec.space.n),
            in_keys=["observation"],
            out_keys=["logits"],
        ),
        spec=env.action_spec,
        distribution_class=OneHotCategorical,
        in_keys=["logits"],
        default_interaction_type=ExplorationType.RANDOM,
    )


if __name__ == "__main__":
    env = ParallelEnv(n_env, lambda: build_cpu_single_env(), device=device)
    # env = TransformedEnv(env)
    policy_module = build_actor(env)
    policy_module.to(device)
    policy_module(env.reset())

    for i in range(10):
        batches = env.rollout((max_step + 3), policy=policy_module, break_when_any_done=False)
        max_step_count = batches["next", "step_count"].max().item()
        if max_step_count > max_step:
            print(max_step_count)
            print(batches["next", "step_count"][:, -5:])
            print("Problem! Got higher than max step count.")
            break
        elif max_step_count < max_step:
            print(max_step_count)
            print(batches["next", "step_count"][:, -5:])
            print("Problem: Got less than max step count!")
            break
    else:
        print(batches["next", "step_count"][:, -5:])
        print("No problem!")

Output:

196
tensor([[[194],
         [195],
         [  1],
         [  2],
         [  3]],

        [[194],
         [195],
         [  1],
         [  2],
         [  3]],

        [[195],
         [196],
         [  1],
         [  2],
         [  3]],

        [[195],
         [196],
         [  1],
         [  2],
         [  3]]], device='mps:0')
Problem: Got less than max step count!

@vmoens (Contributor) commented Feb 12, 2024

Ok, this means I will need access to an MPS device then (I won't have one for the upcoming 3 weeks, I think) :/

@skandermoalla (Contributor, Author) commented Mar 27, 2024

This is now solved, right?
