[BUG] check_env_specs + PixelRenderTransform does not tolerate "cuda" device #2236

Closed · 3 tasks done
N00bcak opened this issue Jun 16, 2024 · 0 comments · Fixed by #2237
Labels: bug Something isn't working

N00bcak (Contributor) commented Jun 16, 2024

Describe the bug

Running check_env_specs on a TransformedEnv that contains a PixelRenderTransform fails when the environment device is "cuda".

To Reproduce

# BEFORE THE PROGRAM EVEN RUNS, FORCE THE START METHOD TO BE 'SPAWN'
from torch import multiprocessing as mp
mp.set_start_method("spawn", force = True)

import numpy as np
import torch

from torchrl.envs import (
    check_env_specs,
    Compose,
    PettingZooEnv,
    RewardSum,
    StepCounter,
    TransformedEnv,
)
from torchrl.record import CSVLogger, PixelRenderTransform, VideoRecorder

# Main Function
if __name__ == "__main__":    
    NUM_AGENTS = 3
    MAX_EPISODE_STEPS = 1000
    DEVICE = "cuda"

    SEED = 42
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    def env_fn(mode, rew_scale = True):

        if rew_scale:
            terminate_scale = -3.0
            forward_scale = 2.5
            fall_scale = -3.0
        else:
            # Use the default reward scales from PettingZoo
            terminate_scale, forward_scale, fall_scale = -100.0, 1.0, -10.0

        def base_env_fn():
            return PettingZooEnv(
                task = "multiwalker_v9",
                parallel = True,
                seed = SEED,
                n_walkers = NUM_AGENTS,
                terminate_reward = terminate_scale,
                forward_reward = forward_scale,
                fall_reward = fall_scale,
                shared_reward = False,
                max_cycles = MAX_EPISODE_STEPS,
                render_mode = mode,
                device = DEVICE,
            )

        def env_with_transforms():
            init_env = base_env_fn()
            init_env = TransformedEnv(
                init_env,
                Compose(
                    StepCounter(max_steps = MAX_EPISODE_STEPS),
                    RewardSum(
                        in_keys = [init_env.reward_key for _ in range(NUM_AGENTS)],
                        out_keys = [("walker", "episode_reward")] * NUM_AGENTS,
                        reset_keys = ["_reset"] * NUM_AGENTS,
                    ),
                ),
            )
            return init_env

        return env_with_transforms

    train_env = env_fn(None)()

    if train_env.is_closed:
        train_env.start()


    def create_eval_env(tag = "rendered"):
        
        eval_env = env_fn("rgb_array", rew_scale = False)()
        video_recorder = VideoRecorder(
            CSVLogger("multiwalker-toy-test", video_format = "mp4"),
            tag = tag,
            in_keys = ["pixels_record"],
        )

        # Call the parent's render function
        eval_env.append_transform(PixelRenderTransform(out_keys = ["pixels_record"]))
        eval_env.append_transform(video_recorder)

        if eval_env.is_closed:
            eval_env.start()
        return eval_env

    check_env_specs(train_env)
    eval_env = create_eval_env()
    check_env_specs(eval_env)

    train_env.close()
    eval_env.close()
File "/mnt/c/Users/N00bcak/Desktop/programming/drones_go_brr/scripts/torchrl_cuda_hangs.py", line 115, in <module>
    check_env_specs(eval_env)
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/utils.py", line 728, in check_env_specs
    fake_tensordict = env.fake_tensordict()
                      ^^^^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/common.py", line 2922, in fake_tensordict
    observation_spec = self.observation_spec
                       ^^^^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/common.py", line 1303, in observation_spec
    observation_spec = self.output_spec["full_observation_spec"]
                       ^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py", line 748, in output_spec
    output_spec = self.transform.transform_output_spec(output_spec)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py", line 1104, in transform_output_spec
    output_spec = t.transform_output_spec(output_spec)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/envs/transforms/transforms.py", line 376, in transform_output_spec
    output_spec["full_observation_spec"] = self.transform_observation_spec(
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/record/recorder.py", line 501, in transform_observation_spec
    observation_spec[self.out_keys[0]] = spec
    ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/n00bcak/venvs/torchrl-3.11/lib/python3.11/site-packages/torchrl/data/tensor_specs.py", line 3783, in __setitem__
    raise RuntimeError(
RuntimeError: Setting a new attribute (pixels_record) on another device (cuda:0 against cuda). All devices of CompositeSpec must match.
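
The failing check can be triggered in isolation. Below is a minimal sketch on a CUDA machine, assuming (as the error message suggests) that a CompositeSpec built with device = "cuda" keeps the index-less device:

import torch
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

spec = CompositeSpec(device = "cuda")  # index-less "cuda" device
# Assigning an entry whose device resolved to the indexed "cuda:0" raises
# the same RuntimeError, because CompositeSpec.__setitem__ requires an
# exact device match.
spec["pixels_record"] = UnboundedContinuousTensorSpec(
    shape = (200, 200, 3), device = "cuda:0"
)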

Expected behavior

check_env_specs succeeds and the program terminates normally.

System info

>>> import torchrl, numpy, sys
>>> print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
0.4.0 1.26.4 3.11.9 (main, Jun  5 2024, 10:27:27) [GCC 12.2.0] linux

Reason and Possible fixes

A strict equality check appears to be conducted on the devices of the specs, so the index-less torch.device("cuda") and the indexed torch.device("cuda:0") are treated as different devices even when they refer to the same GPU.
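
This mirrors plain PyTorch device semantics, where an index-less device object is not equal to an indexed one:

import torch

# The index-less device has index None, the explicit one has index 0,
# so the two objects compare unequal even on a single-GPU machine.
print(torch.device("cuda") == torch.device("cuda:0"))  # False
print(torch.device("cuda").index, torch.device("cuda:0").index)  # None 0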

For consistency with PyTorch in general, one could substitute an index-less "cuda" with f"cuda:{torch.cuda.current_device()}" before the comparison.

Depending on whether an equivalent of current_device() is available for other accelerator backends, similar canonicalization could be implemented for those device types as well.
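
A minimal sketch of such a canonicalization (the helper name and placement are hypothetical, not existing torchrl API):

import torch

def _canonicalize_device(device):
    # Hypothetical helper: resolve an index-less "cuda" device to the
    # current default CUDA device so it compares equal to "cuda:0".
    device = torch.device(device)
    if device.type == "cuda" and device.index is None:
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    return device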

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
@N00bcak N00bcak added the bug Something isn't working label Jun 16, 2024
@vmoens vmoens linked a pull request Jun 18, 2024 that will close this issue