
[BUG] MultiaSyncDataCollector crashes when passing Replay Buffer to constructor #2614

Open
AlexandreBrown opened this issue Nov 28, 2024 · 8 comments
Labels: bug (Something isn't working)

@AlexandreBrown

Describe the bug

Creating an instance of MultiaSyncDataCollector crashes when replay_buffer=my_replay_buffer is passed to its constructor.
The following log is observed:

python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 8 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

To Reproduce

  1. Create a replay buffer

storage_kwargs = {}
storage_kwargs["max_size"] = 1_000_000
storage_kwargs["device"] = "cpu"

replay_buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(**storage_kwargs),
    transform=transform,
)
  2. Try to create an instance of MultiaSyncDataCollector

MultiaSyncDataCollector(
    create_env_fn=[create_env_fn],
    policy=policy,
    exploration_type=exploration_type,
    replay_buffer=replay_buffer,
    total_frames=cfg["training"]["data_collector"]["total_frames"],
    max_frames_per_traj=cfg["env"]["max_frames_per_traj"],
    frames_per_batch=frames_per_batch,
    init_random_frames=cfg["training"]["data_collector"]["init_random_frames"],
    reset_at_each_iter=False,
    env_device=cfg["env"]["device"],
    storing_device=cfg["storage_device"],
    policy_device=cfg["policy_device"],
    cat_results="stack",
)
  3. Observe the crash

Expected behavior

No crash.

System info

Describe the characteristic of your environment:

  • Describe how the library was installed: pip
  • Python version: 3.10
  • Versions of any other relevant libraries:
pip list
Package                   Version                                                       Editable project location
------------------------- ------------------------------------------------------------- ------------------------------------------
absl-py                   2.1.0
antlr4-python3-runtime    4.9.3
asttokens                 2.4.1
attrs                     24.2.0
av                        13.1.0
certifi                   2024.8.30
charset-normalizer        3.4.0
click                     8.1.7
clip                      1.0
cloudpickle               3.1.0
coloredlogs               15.0.1
comet-ml                  3.47.1
comm                      0.2.2
configobj                 5.0.9
contourpy                 1.3.1
cycler                    0.12.1
Cython                    3.0.11
debugpy                   1.8.9
decorator                 5.1.1
diffusers                 0.31.0
dm_control                1.0.25
dm-env                    1.6
dm-tree                   0.1.8
docker-pycreds            0.4.0
drqv2                     1.0.0                                                        
dulwich                   0.22.6
efficientvit              0.0.0
einops                    0.8.0
etils                     1.10.0
everett                   3.1.0
exceptiongroup            1.2.2
executing                 2.1.0
filelock                  3.16.1
flatbuffers               24.3.25
fonttools                 4.55.0
fsspec                    2024.10.0
ftfy                      6.3.1
gitdb                     4.0.11
GitPython                 3.1.43
glfw                      2.8.0
huggingface-hub           0.26.2
humanfriendly             10.0
hydra-core                1.3.2
idna                      3.10
igraph                    0.11.8
imageio                   2.36.0
importlib_metadata        8.5.0
importlib_resources       6.4.5
ipdb                      0.13.13
ipykernel                 6.29.5
ipython                   8.29.0
jedi                      0.19.2
Jinja2                    3.1.4
jsonschema                4.23.0
jsonschema-specifications 2024.10.1
jupyter_client            8.6.3
jupyter_core              5.7.2
kiwisolver                1.4.7
labmaze                   1.0.6
lazy_loader               0.4
lightning-utilities       0.11.9
loguru                    0.7.2
lvis                      0.5.3
lxml                      5.3.0
markdown-it-py            3.0.0
MarkupSafe                3.0.2
matplotlib                3.9.2
matplotlib-inline         0.1.7
mdurl                     0.1.2
mpmath                    1.3.0
mujoco                    3.2.5
nest-asyncio              1.6.0
networkx                  3.4.2
numpy                     2.1.3
nvidia-cublas-cu12        12.4.5.8
nvidia-cuda-cupti-cu12    12.4.127
nvidia-cuda-nvrtc-cu12    12.4.127
nvidia-cuda-runtime-cu12  12.4.127
nvidia-cudnn-cu12         9.1.0.70
nvidia-cufft-cu12         11.2.1.3
nvidia-curand-cu12        10.3.5.147
nvidia-cusolver-cu12      11.6.1.9
nvidia-cusparse-cu12      12.3.1.170
nvidia-nccl-cu12          2.21.5
nvidia-nvjitlink-cu12     12.4.127
nvidia-nvtx-cu12          12.4.127
omegaconf                 2.3.0
onnx                      1.17.0
onnxruntime               1.20.1
onnxsim                   0.4.36
opencv-python             4.10.0.84
opencv-python-headless    4.10.0.84
orjson                    3.10.12
packaging                 24.2
pandas                    2.2.3
parso                     0.8.4
pexpect                   4.9.0
pillow                    11.0.0
pip                       24.3.1
platformdirs              4.3.6
prompt_toolkit            3.0.48
protobuf                  5.28.3
psutil                    6.1.0
ptyprocess                0.7.0
pure_eval                 0.2.3
py-cpuinfo                9.0.0
pycocotools               2.0.8
Pygments                  2.18.0
PyOpenGL                  3.1.7
PyOpenGL-accelerate       3.1.7
pyparsing                 3.2.0
python-box                6.1.0
python-dateutil           2.9.0.post0
pytz                      2024.2
PyYAML                    6.0.2
pyzmq                     26.2.0
referencing               0.35.1
regex                     2024.11.6
requests                  2.32.3
requests-toolbelt         1.0.0
rich                      13.9.4
rpds-py                   0.21.0
ruamel.yaml               0.18.6
ruamel.yaml.clib          0.2.12
safetensors               0.4.5
scikit-image              0.24.0
scipy                     1.14.1
seaborn                   0.13.2
OBFUSCATED                    0.0.1                                                         
OBFUSCATED                0.0.1                                                         
segment_anything          1.0
semantic-version          2.10.0
sentry-sdk                2.19.0
setproctitle              1.3.4
setuptools                75.6.0
simplejson                3.19.3
six                       1.16.0
smmap                     5.0.1
stack-data                0.6.3
sympy                     1.13.1
tensordict                0.6.0
texttable                 1.7.0
tifffile                  2024.9.20
timm                      1.0.11
TinyNeuralNetwork         0.1.0.20241024123327+19e5f6dd0f6e391d3c3640cf46d28f47eb76d289
tokenizers                0.20.4
tomli                     2.1.0
torch                     2.5.0
torch-fidelity            0.3.0
torchaudio                2.5.0
torchmetrics              1.6.0
torchprofile              0.0.4
torchrl                   0.6.0
torchvision               0.20.0
tornado                   6.4.2
tqdm                      4.66.5
traitlets                 5.14.3
transformers              4.46.3
triton                    3.1.0
typing_extensions         4.12.2
tzdata                    2024.2
ultralytics               8.3.38
ultralytics-thop          2.0.12
urllib3                   2.2.3
wandb                     0.18.7
wcwidth                   0.2.13
wheel                     0.45.1
wrapt                     1.17.0
wurlitzer                 3.1.1
zipp                      3.21.0

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
@AlexandreBrown AlexandreBrown added the bug Something isn't working label Nov 28, 2024
@vmoens (Contributor) commented Nov 28, 2024

That error is probably not the one responsible for the crash but a remnant of the original one. Is there anything else in the error stack you can share?

@AlexandreBrown (Author)

I only see this log plus "Killed".
If I do not pass the replay buffer to the data collector, creation works, but it crashes when it tries to collect with the policy.
It then raises a `queue.Empty` error.

@vmoens (Contributor) commented Nov 28, 2024

Does it run on a single process (SyncDataCollector)?

@AlexandreBrown (Author)

Yes, everything works fine with SyncDataCollector (both with and without passing the replay buffer to the data collector, it works as expected).

@vmoens (Contributor) commented Nov 28, 2024

Can your policy be serialized? Have you changed the default multiprocessing start method? (I think "fork" will fail.)
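For reference, the start method can be checked and set via the standard library. This is a minimal sketch, not torchrl-specific; torchrl may manage the start method itself, so treat the `set_start_method` call as illustrative:

```python
import multiprocessing as mp

# Inspect the current default start method (None if not yet fixed).
print(mp.get_start_method(allow_none=True))

# Explicitly request "spawn"; force=True overrides any previously set
# method. "spawn" is generally the safer choice when CUDA is involved.
mp.set_start_method("spawn", force=True)
print(mp.get_start_method())  # spawn
```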

@AlexandreBrown (Author) commented Nov 29, 2024

@vmoens I tried setting the start method to both fork and spawn, and both crash.
But I can no longer reproduce the initial error when I pass my replay buffer.
Now the data collector gets created, but it crashes during collection:

Traceback (most recent call last):
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/_utils.py", line 669, in run
    return mp.Process.run(self, *args, **kwargs)
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 2960, in _main_async_collector
    next_data = next(dc_iter)
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 247, in __iter__
    yield from self.iterator()
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1035, in iterator
    tensordict_out = self.rollout()
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/_utils.py", line 481, in unpack_rref_and_invoke_function
    return func(self, *args, **kwargs)
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 1177, in rollout
    self.replay_buffer.add(self._shuttle)
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 1199, in add
    index = super()._add(data)
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/data/replay_buffers/replay_buffers.py", line 598, in _add
    index = self._writer.add(data)
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/data/replay_buffers/writers.py", line 276, in add
    self._storage.set(index, data)
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/_utils.py", line 394, in _lazy_call_fn
    result = self._delazify(self.func_name)(*args, **kwargs)
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/data/replay_buffers/storages.py", line 724, in set
    self._storage[cursor] = data
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/tensordict/_td.py", line 900, in __setitem__
    subtd.set(value_key, item, inplace=True, non_blocking=False)
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/tensordict/base.py", line 4706, in set
    return self._set_tuple(
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/tensordict/_td.py", line 3483, in _set_tuple
    return self._set_str(
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/tensordict/_td.py", line 3409, in _set_str
    inplace = self._convert_inplace(inplace, key)
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/tensordict/_td.py", line 3394, in _convert_inplace
    raise RuntimeError(_LOCK_ERROR)
RuntimeError: Cannot modify locked TensorDict. For in-place modification, consider using the `set_()` method and make sure the key is present.
Env Data Collection: 0it [00:10, ?it/s]
Error executing job with overrides: ['env=dmc_cartpole_balance', 'algo=sac_pixels']
Traceback (most recent call last):
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/mila/b/myuser/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 71, in <module>
    cli.main()
  File "/home/mila/b/myuser/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 501, in main
    run()
  File "/home/mila/b/myuser/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 351, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/mila/b/myuser/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 310, in run_path
    return _run_module_code(code, init_globals, run_name, pkg_name=pkg_name, script_name=fname)
  File "/home/mila/b/myuser/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 127, in _run_module_code
    _run_code(code, mod_globals, init_globals, mod_name, mod_spec, pkg_name, script_name)
  File "/home/mila/b/myuser/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 118, in _run_code
    exec(code, run_globals)
  File "scripts/train_rl.py", line 131, in <module>
    main()
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/main.py", line 94, in decorated_main
    _run_hydra(
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
    _run_app(
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 457, in _run_app
    run_and_report(
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
    raise ex
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
    return func()
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
    lambda: hydra.run(
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/_internal/hydra.py", line 132, in run
    _ = ret.return_value
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/core/utils.py", line 260, in return_value
    raise self._return_value
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/hydra/core/utils.py", line 186, in run_job
    ret.return_value = task_function(task_cfg)
  File "scripts/train_rl.py", line 123, in main
    trainer.train()
  File "/home/mila/b/myuser/SegDAC/segdac_dev/src/segdac_dev/trainers/rl_trainer.py", line 41, in train
    for _ in tqdm(self.train_data_collector, "Env Data Collection"):
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/tqdm/std.py", line 1181, in __iter__
    for obj in iterable:
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 247, in __iter__
    yield from self.iterator()
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 2584, in iterator
    idx, j, out = self._get_from_queue(timeout=10.0)
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/site-packages/torchrl/collectors/collectors.py", line 2540, in _get_from_queue
    new_data, j = self.queue_out.get(timeout=timeout)
  File "/home/mila/b/myuser/micromamba/envs/dmc_env/lib/python3.10/multiprocessing/queues.py", line 114, in get
    raise Empty
_queue.Empty
[W1129 09:45:22.315415632 CudaIPCTypes.cpp:16] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]

There seem to be two error messages. One is:

RuntimeError: Cannot modify locked TensorDict. For in-place modification, consider using the `set_()` method and make sure the key is present.

and the other is:

multiprocessing/queues.py", line 114, in get
    raise Empty
_queue.Empty
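The second error is consistent with being a downstream symptom of the first: if the worker process dies after hitting the locked-TensorDict error, it never puts data on the output queue, and the main process's timed `get` then raises `queue.Empty`. A minimal stdlib sketch of that mechanism (not torchrl code):

```python
import multiprocessing as mp
import queue

# Simulate a consumer whose producer never sent anything (e.g. the
# producer process crashed): get(timeout=...) raises queue.Empty
# once the timeout elapses.
q = mp.get_context("spawn").Queue()
try:
    q.get(timeout=0.2)
    outcome = "got data"
except queue.Empty:
    outcome = "queue.Empty"
print(outcome)  # queue.Empty
```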

@vmoens (Contributor) commented Nov 29, 2024

Interesting!
Can you share more about the replay buffer? What does it look like and how do you instantiate it?

@AlexandreBrown (Author) commented Nov 29, 2024

Sure, here is how I create it:

import torch
from omegaconf import DictConfig
from torchrl.data import ReplayBuffer
from torchrl.data import TensorDictReplayBuffer
from torchrl.data import LazyMemmapStorage
from torchrl.data import LazyTensorStorage
from torchrl.envs.transforms import Compose
from torchrl.envs.transforms import ToTensorImage
from torchrl.envs.transforms import CatFrames
from torchrl.envs.transforms import Transform
from torchrl.envs.transforms import ExcludeTransform


class UnscaleActionTransform(Transform):
    def __init__(self, env_action_scaler):
        super().__init__(in_keys=["action"], out_keys=["action"])
        self.env_action_scaler = env_action_scaler

    def _apply_transform(self, scaled_action):
        return self.env_action_scaler.unscale(scaled_action)


def get_replay_buffer_data_saving_transforms() -> list:
    """
    These are transforms executed when saving data to the replay buffer.
    We exclude pixels_transformed because it is float32 (expensive to store); the uint8 RGB image can be stored instead.
    """
    return [
        ExcludeTransform(
            "pixels_transformed", ("next", "pixels_transformed"), inverse=True
        )
    ]


def get_replay_buffer_sample_transforms(cfg: DictConfig, env_action_scaler) -> list:
    """
    These are transforms executed when sampling data from the replay buffer.
    """
    transforms = []
    if "pixels" in cfg["algo"]["training_keys"]:
        transforms.append(
            ToTensorImage(
                from_int=True,
                in_keys=["pixels", ("next", "pixels")],
                out_keys=["pixels_transformed", ("next", "pixels_transformed")],
                shape_tolerant=True,
            )
        )

        frame_stack = cfg["algo"].get("frame_stack")
        if frame_stack is not None:
            transforms.append(
                CatFrames(
                    N=int(frame_stack),
                    dim=-3,
                    in_keys=["pixels_transformed", ("next", "pixels_transformed")],
                    out_keys=["pixels_transformed", ("next", "pixels_transformed")],
                )
            )

    transforms.append(UnscaleActionTransform(env_action_scaler))

    return transforms


def create_replay_buffer(cfg: DictConfig, env_action_scaler) -> ReplayBuffer:
    storage_device = torch.device(cfg["storage_device"])
    capacity = cfg["algo"]["replay_buffer"]["capacity"]

    transforms = []
    transforms.extend(get_replay_buffer_data_saving_transforms())
    transforms.extend(get_replay_buffer_sample_transforms(cfg, env_action_scaler))
    transform = Compose(*transforms)

    storage_kwargs = {}
    storage_kwargs["max_size"] = capacity
    storage_kwargs["device"] = storage_device
    if cfg["env"]["num_workers"] > 1:
        storage_kwargs["ndim"] = 2

    if "cpu" in storage_device.type:
        # LazyMemmapStorage is only supported on CPU
        replay_buffer = TensorDictReplayBuffer(
            storage=LazyMemmapStorage(**storage_kwargs),
            transform=transform,
        )
    else:
        replay_buffer = TensorDictReplayBuffer(
            storage=LazyTensorStorage(**storage_kwargs),
            transform=transform,
        )

    return replay_buffer
