training hangs with lightning ddp and cloud dir? #408

Closed
rxqy opened this issue Nov 1, 2024 · 17 comments · Fixed by #468
Labels
bug Something isn't working help wanted Extra attention is needed

Comments

@rxqy

rxqy commented Nov 1, 2024

🐛 Bug

Hi, we are using Lightning with LitData on our local machine and an AWS S3 bucket. However, training hangs randomly during the very first iterations when using DDP with a remote cloud directory.

I tried several different configurations, but I'm not sure what I should check next.
GPUs / Strategy / Files on / Result
1 / No DDP / local SSD / OK
1 / No DDP / remote (S3) / OK
8 / DDP / local SSD / OK
8 / DDP / remote (S3) / Stuck

To Reproduce

I'm following the exact steps in the ImageNet demo, and I wrote the trainer below myself.
Running python train.py with different CUDA_VISIBLE_DEVICES settings is enough to reproduce.

Code sample
# train.py
import numpy as np
import lightning as L
import torch, torch.nn as nn, torch.utils.data as data, torchvision as tv, torch.nn.functional as F

class LitAutoEncoder(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.decoder = nn.Sequential(nn.Linear(32, 128))

    def training_step(self, batch, batch_idx):
        loss = self.decoder(batch).mean()
        print(self.trainer.local_rank, loss)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


from lightning.data import StreamingDataset, StreamingDataLoader

class ImageNetStreaming(StreamingDataset):
    def __init__(self, ):
        if 1:
            input_dir = "s3:// xxxxx"
            cache_dir = None
        else:
            input_dir = "val"
            cache_dir = None

        max_cache_size = "200GB"
        super().__init__(
            input_dir = input_dir,
            max_cache_size = max_cache_size,
            shuffle = True,
        )

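    def __getitem__(self, idx):
        # Fetch the real sample so the streaming/download path is exercised,
        # then return a constant so the sample content does not matter.
        data = super().__getitem__(idx)
        return np.float32(123.)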

dataset = ImageNetStreaming()
dataloader = StreamingDataLoader(
    dataset,
    batch_size = 32,
    num_workers = 2,
    pin_memory = True,
    shuffle = True,
    drop_last = True
)

autoencoder = LitAutoEncoder()
trainer = L.Trainer()
trainer.fit(autoencoder, dataloader)

Expected behavior

Training should finish

Additional context

Due to some regulations here, we cannot put our data or training scripts on Lightning Studio. I'm not sure if something's wrong with our S3 bucket or our network configuration.
One thing I noticed is that even when training gets stuck at some iteration (<50), we can still observe large network throughput on our machine (around 100mb/s), but the local chunk directory (~/.lightning/chunks) stops growing.

Current environment
  • CUDA:
    • GPU:
      • Tesla V100-SXM2-32GB
      • Tesla V100-SXM2-32GB
    • available: True
    • version: 12.1
  • Lightning:
    • lightning: 2.3.0
    • lightning-utilities: 0.11.1
    • pytorch-lightning: 2.2.1
    • torch: 2.2.1
    • torchaudio: 2.2.1
    • torchmetrics: 1.3.2
    • torchvision: 0.17.1
  • Packages:
    • absl-py: 2.1.0
    • accelerate: 0.30.1
    • aiofiles: 23.2.1
    • aiohttp: 3.9.3
    • aiosignal: 1.3.1
    • angle-emb: 0.3.10
    • annotated-types: 0.7.0
    • anyio: 4.4.0
    • async-timeout: 4.0.3
    • attrs: 23.2.0
    • auto-gptq: 0.7.1
    • av: 12.3.0
    • awscli: 1.32.70
    • backports-datetime-fromisoformat: 2.0.1
    • bitsandbytes: 0.43.1
    • blessed: 1.20.0
    • blinker: 1.7.0
    • boltons: 24.0.0
    • boto3: 1.34.143
    • botocore: 1.34.143
    • braceexpand: 0.1.7
    • brotli: 1.0.9
    • certifi: 2024.2.2
    • charset-normalizer: 2.0.4
    • click: 8.1.7
    • colorama: 0.4.4
    • coloredlogs: 15.0.1
    • contourpy: 1.2.1
    • cos-python-sdk-v5: 1.9.30
    • crcmod: 1.7
    • cycler: 0.12.1
    • datasets: 2.14.6
    • decord: 0.6.0
    • deepspeed: 0.14.0
    • dill: 0.3.7
    • dnspython: 2.6.1
    • docker-pycreds: 0.4.0
    • docstring-parser: 0.16
    • docutils: 0.16
    • einops: 0.7.0
    • email-validator: 2.2.0
    • et-xmlfile: 1.1.0
    • exceptiongroup: 1.2.2
    • faiss-gpu: 1.7.2
    • fastapi: 0.111.1
    • fastapi-cli: 0.0.4
    • ffmpy: 0.3.2
    • filelock: 3.13.1
    • fire: 0.6.0
    • flash-attn: 2.5.7
    • flask: 3.0.3
    • fonttools: 4.51.0
    • frozenlist: 1.4.1
    • fsspec: 2023.10.0
    • gekko: 1.2.1
    • gitdb: 4.0.11
    • gitpython: 3.1.43
    • gmpy2: 2.1.2
    • gpustat: 1.1.1
    • gradio: 4.39.0
    • gradio-client: 1.1.1
    • grpcio: 1.62.1
    • h11: 0.14.0
    • hide-warnings: 0.17
    • hjson: 3.1.0
    • httpcore: 1.0.5
    • httptools: 0.6.1
    • httpx: 0.27.0
    • huggingface-hub: 0.23.4
    • humanfriendly: 10.0
    • idna: 3.4
    • importlib-resources: 6.4.0
    • influxdb: 5.3.2
    • itsdangerous: 2.1.2
    • jinja2: 3.1.3
    • jmespath: 1.0.1
    • joblib: 1.4.0
    • jsonargparse: 4.27.7
    • kafka-python: 2.0.2
    • kiwisolver: 1.4.5
    • lightning: 2.3.0
    • lightning-utilities: 0.11.1
    • litdata: 0.2.29
    • llava: 1.7.0.dev0
    • llmtuner: 0.6.3.dev0
    • m3u8: 4.0.0
    • markdown: 3.6
    • markdown-it-py: 3.0.0
    • markupsafe: 2.1.3
    • matplotlib: 3.8.4
    • mdurl: 0.1.2
    • media-metric: 0.2.0.10
    • mkl-fft: 1.3.8
    • mkl-random: 1.2.4
    • mkl-service: 2.4.0
    • mmidls: 2.0.3
    • mpmath: 1.3.0
    • msgpack: 1.1.0
    • multidict: 6.0.5
    • multiprocess: 0.70.15
    • networkx: 3.1
    • ninja: 1.11.1.1
    • nssdk: 0.0.1
    • numpy: 1.26.4
    • nvidia-ml-py: 12.535.133
    • onnx: 1.16.0
    • onnxconverter-common: 1.14.0
    • opencv-python-headless: 4.9.0.80
    • openpyxl: 3.1.5
    • optimum: 1.21.1
    • orjson: 3.10.6
    • packaging: 24.0
    • pandas: 2.2.1
    • peft: 0.11.1
    • pillow: 10.2.0
    • pip: 23.3.1
    • platformdirs: 4.2.2
    • ply: 3.11
    • prettytable: 3.10.0
    • protobuf: 3.20.2
    • psutil: 5.9.8
    • py: 1.11.0
    • py-cpuinfo: 9.0.0
    • pyarrow: 15.0.2
    • pyarrow-hotfix: 0.6
    • pyasn1: 0.5.1
    • pycryptodome: 3.20.0
    • pydantic: 2.7.1
    • pydantic-core: 2.18.2
    • pydub: 0.25.1
    • pygments: 2.18.0
    • pynvml: 11.5.0
    • pyparsing: 3.1.2
    • pyrootutils: 1.0.4
    • pysocks: 1.7.1
    • python-dateutil: 2.9.0.post0
    • python-dotenv: 1.0.1
    • python-multipart: 0.0.9
    • pytorch-lightning: 2.2.1
    • pytz: 2024.1
    • pyyaml: 6.0.1
    • redis: 5.0.3
    • regex: 2023.12.25
    • requests: 2.31.0
    • rich: 13.7.1
    • rocketmq-client-python: 2.0.0
    • rouge: 1.0.1
    • rsa: 4.7.2
    • ruff: 0.5.4
    • s3transfer: 0.10.1
    • safetensors: 0.4.2
    • scikit-learn: 1.4.2
    • scipy: 1.13.0
    • seaborn: 0.13.2
    • semantic-version: 2.10.0
    • sentencepiece: 0.2.0
    • sentry-sdk: 2.5.1
    • setproctitle: 1.3.3
    • setuptools: 68.2.2
    • shellingham: 1.5.4
    • shtab: 1.7.1
    • six: 1.16.0
    • smmap: 5.0.1
    • sniffio: 1.3.1
    • sse-starlette: 2.1.2
    • starlette: 0.37.2
    • sympy: 1.12
    • tabulate: 0.9.0
    • taxonomy: 0.10.0
    • tensorboard: 2.16.2
    • tensorboard-data-server: 0.7.2
    • termcolor: 2.4.0
    • threadpoolctl: 3.4.0
    • thrift: 0.20.0
    • thriftpy2: 0.4.20
    • tiktoken: 0.7.0
    • timm: 1.0.3
    • tokenizers: 0.19.1
    • tomlkit: 0.12.0
    • torch: 2.2.1
    • torchaudio: 2.2.1
    • torchmetrics: 1.3.2
    • torchvision: 0.17.1
    • tqdm: 4.66.2
    • transformers: 4.42.4
    • transformers-stream-generator: 0.0.5
    • triton: 2.2.0
    • trl: 0.9.6
    • typer: 0.12.3
    • typeshed-client: 2.5.1
    • typing-extensions: 4.9.0
    • tyro: 0.8.5
    • tzdata: 2024.1
    • urllib3: 2.1.0
    • uvicorn: 0.30.3
    • uvloop: 0.19.0
    • videollama2: 1.0
    • wandb: 0.17.1
    • watchfiles: 0.22.0
    • wcwidth: 0.2.13
    • webdataset: 0.2.93
    • websockets: 11.0.3
    • werkzeug: 3.0.1
    • wheel: 0.41.2
    • xlrd: 2.0.1
    • xmltodict: 0.13.0
    • xxhash: 3.4.1
    • yarl: 1.9.4
  • System:
@rxqy rxqy added bug Something isn't working help wanted Extra attention is needed labels Nov 1, 2024

github-actions bot commented Nov 1, 2024

Hi! thanks for your contribution!, great first issue!

@deependujha
Collaborator

deependujha commented Nov 3, 2024

Hi @rxqy, thanks for opening the issue. A similar issue is also open for Sagemaker.

We're looking into it and will try to fix it ASAP.

@rxqy
Author

rxqy commented Nov 4, 2024

@deependujha, many thanks. BTW, the above code sometimes raises a FileNotFoundError (the training loop then continues for several iterations and hangs), and sometimes it just hangs. Not sure if it will help or not, but I'm pasting it here anyway.

Epoch 0:   0%|                                              | 3/10008 [00:04<4:05:42,  0.68it/s, v_num=3]1 tensor(-45.3273, device='cuda:1', grad_fn=<MeanBackward0>)
0 tensor(-45.3273, device='cuda:0', grad_fn=<MeanBackward0>)
Epoch 0:   0%|                                              | 4/10008 [00:05<3:41:55,  0.75it/s, v_num=3]1 tensor(-61.0723, device='cuda:1', grad_fn=<MeanBackward0>)
0 tensor(-61.0723, device='cuda:0', grad_fn=<MeanBackward0>)
Epoch 0:   0%|                                              | 5/10008 [00:05<2:57:38,  0.94it/s, v_num=3]Exception in thread Thread-3:
Traceback (most recent call last):
  File "/data/miniconda3/envs/pl/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/data/miniconda3/envs/pl/lib/python3.10/site-packages/litdata/streaming/reader.py", line 153, in run
    self._maybe_delete_chunks()
  File "/data/miniconda3/envs/pl/lib/python3.10/site-packages/litdata/streaming/reader.py", line 117, in _maybe_delete_chunks
    self._apply_delete(self._chunks_index_to_be_deleted.pop(0))
  File "/data/miniconda3/envs/pl/lib/python3.10/site-packages/litdata/streaming/reader.py", line 91, in _apply_delete
    os.remove(locak_chunk_path)
FileNotFoundError: [Errno 2] No such file or directory: '/data/.lightning/chunks/b515aeecb3a09f152677fce166405b10/1730182031.4955728/chunk-4-1309.bin.lock'
0 tensor(-76.8173, device='cuda:0', grad_fn=<MeanBackward0>)
Epoch 0:   0%|                                              | 6/10008 [00:05<2:45:20,  1.01it/s, v_num=3]1 tensor(-76.8173, device='cuda:1', grad_fn=<MeanBackward0>)
1 tensor(-92.5623, device='cuda:1', grad_fn=<MeanBackward0>)
0 tensor(-92.5623, device='cuda:0', grad_fn=<MeanBackward0>)
Epoch 0:   0%|                                              | 7/10008 [00:08<3:18:18,  0.84it/s, v_num=3]1 tensor(-108.3073, device='cuda:1', grad_fn=<MeanBackward0>)
0 tensor(-108.3073, device='cuda:0', grad_fn=<MeanBackward0>)
Epoch 0:   0%|                                              | 8/10008 [00:08<2:53:34,  0.96it/s, v_num=3]1 tensor(-124.0523, device='cuda:1', grad_fn=<MeanBackward0>)
0 tensor(-124.0523, device='cuda:0', grad_fn=<MeanBackward0>)
Epoch 0:   0%|                                              | 9/10008 [00:08<2:34:20,  1.08it/s, v_num=3]1 tensor(-139.7973, device='cuda:1', grad_fn=<MeanBackward0>)
0 tensor(-139.7973, device='cuda:0', grad_fn=<MeanBackward0>)
Epoch 0:   0%|                                             | 10/10008 [00:09<2:38:37,  1.05it/s, v_num=3]1 tensor(-155.5423, device='cuda:1', grad_fn=<MeanBackward0>)
0 tensor(-155.5423, device='cuda:0', grad_fn=<MeanBackward0>)
Epoch 0:   0%|                                             | 11/10008 [00:09<2:30:48,  1.10it/s, v_num=3]0 tensor(-171.2873, device='cuda:0', grad_fn=<MeanBackward0>)
Epoch 0:   0%|                                             | 12/10008 [00:09<2:18:17,  1.20it/s, v_num=3]1 tensor(-171.2873, device='cuda:1', grad_fn=<MeanBackward0>)
0 tensor(-187.0323, device='cuda:0', grad_fn=<MeanBackward0>)
1 tensor(-187.0323, device='cuda:1', grad_fn=<MeanBackward0>)
Epoch 0:   0%|                                             | 13/10008 [00:10<2:13:49,  1.24it/s, v_num=3]1 tensor(-202.7773, device='cuda:1', grad_fn=<MeanBackward0>)

@tchaton
Collaborator

tchaton commented Nov 22, 2024

Hey @rxqy. Could you try adding a try/except around it in LitData and let us know if it helps? There is a race condition on deleting the file, but it is fine to catch the error and skip it.

If it helps, would you mind making a PR with the fix ?
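
A minimal sketch of the suggested guard, assuming the deletion boils down to an os.remove call on the chunk's .lock file as in the traceback above. The helper name remove_if_present is hypothetical and is not part of LitData.

```python
import os


def remove_if_present(path: str) -> bool:
    """Delete `path`, tolerating the case where another thread or process
    removed it first. Returns True if this call deleted the file."""
    try:
        os.remove(path)
        return True
    except FileNotFoundError:
        # Lost the race: the file is already gone, which is the desired end state.
        return False
```

Catching only FileNotFoundError keeps other failures (permissions, I/O errors) visible instead of silently swallowing them.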

@rxqy
Author

rxqy commented Nov 23, 2024

Hi @tchaton. I think this should be handled on the Lightning side?
Probably related: #411. I'm observing quite similar behavior, e.g., the exact same dataloader length with different numbers of GPUs when using the Lightning trainer and its DDP strategy.
Going by the issue above, we should initialize our streaming dataloader after DDP init, but with Lightning's trainer we are initializing the dataloader before DDP init.

I wrote a pytorch ddp demo. With the exact same dataloader, we can finish training quite smoothly.

import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

import numpy as np
import lightning
from lightning.data import StreamingDataset, StreamingDataLoader
from train import ImageNetStreaming
from tqdm import tqdm

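# NOTE: ToyModel is not defined in the original snippet. This definition is an
# assumption that mirrors the decoder in train.py (each batch has shape [32]).
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(32, 128)

    def forward(self, x):
        return self.net(x)

def demo_basic():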
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    print(f"Start running basic DDP example on rank {rank}.")
    # create model and move it to GPU with id rank
    device_id = rank % torch.cuda.device_count()
    model = ToyModel().to(device_id)
    ddp_model = DDP(model, device_ids=[device_id])
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    dataset = ImageNetStreaming()
    dataloader = StreamingDataLoader(
        dataset,
        batch_size = 32,
        num_workers = 2,
        pin_memory = True,
        shuffle = True,
        drop_last = True
    )
    print(len(dataset), len(dataloader))

    for _, data in enumerate(tqdm(dataloader, disable=(rank!=0))):
        optimizer.zero_grad()
        outputs = ddp_model(data)
        loss = outputs.mean()
        loss.backward()
        optimizer.step()
        print(rank, loss)


    dist.destroy_process_group()
    print(f"Finished running basic DDP example on rank {rank}.")

if __name__ == "__main__":
    demo_basic()

@rxqy
Author

rxqy commented Nov 23, 2024

Just to clarify, I made no code change to my litdata or lightning package. And we are not using fabric in our trainer.
You can launch the above pt ddp script with
torchrun --nnodes=1 --nproc_per_node=2 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=127.0.0.1:29400 ddp_demo.py.
Not sure what's the correct way to init streaming dataloader with lightning trainer.

@tchaton
Collaborator

tchaton commented Nov 23, 2024

You should instantiate the dataset in the setup hook of the datamodule or directly within the dataloader hook.
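
A rough sketch of that advice, assuming a data module that only stores factories in __init__ and builds the dataset inside the hooks. The class name LazyStreamingDataModule and the two factory arguments are hypothetical; the import fallback exists only so the sketch can be run without Lightning installed.

```python
try:
    from lightning import LightningDataModule as _Base
except ImportError:  # fallback so the sketch runs without Lightning
    _Base = object


class LazyStreamingDataModule(_Base):
    """Hypothetical data module that defers dataset creation to the hooks."""

    def __init__(self, make_dataset, make_loader):
        super().__init__()
        # Store only the factories here; __init__ runs before DDP is set up.
        self._make_dataset = make_dataset
        self._make_loader = make_loader
        self.train_dataset = None

    def setup(self, stage=None):
        # Build the dataset here rather than at import time or in __init__.
        if stage in ("fit", None) and self.train_dataset is None:
            self.train_dataset = self._make_dataset()

    def train_dataloader(self):
        if self.train_dataset is None:  # covers a direct call without setup()
            self.setup("fit")
        return self._make_loader(self.train_dataset)
```

With the reproduction script, this could be used as LazyStreamingDataModule(ImageNetStreaming, lambda ds: StreamingDataLoader(ds, batch_size=32, num_workers=2)) and passed to trainer.fit(autoencoder, datamodule=dm).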

@rxqy
Author

rxqy commented Nov 25, 2024

Hi @tchaton, what's the dataloader hook you are mentioning here?
I'm looking at the datamodule's doc here, and it seems that only the setup method (for splitting train/test datasets) is called on every device, but not the train_dataloader method.

@schopra8

schopra8 commented Dec 31, 2024

@tchaton We're running into an identical issue. We also are getting:

GPUs / Strategy / Files on / Result
1 / No DDP / local SSD / OK
1 / No DDP / remote (S3) / OK
8 / DDP / local SSD / OK
8 / DDP / remote (S3) / Stuck

We are instantiating our dataset in the setup hook of our custom DataModule:

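from typing import List

from lightning import LightningDataModule
from litdata import CombinedStreamingDataset

# CustomStreamingDataset and CustomStreamingDatasetConfig are our own classes (not shown).

class TCDataModule(LightningDataModule):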
    """
    Text Conditional Data Module wraps training streaming datasets.
    """

    def __init__(self,
                 train_dataset_configs: List[CustomStreamingDatasetConfig]):
        super().__init__()
        self.train_dataset_configs = train_dataset_configs

    def setup(self, stage=None):
        """
        Setup the module, by creating the underlying training datasets.
        """
        if stage in ['fit', None]:
            # Training
            train_datasets = []
            for config in self.train_dataset_configs:
                train_datasets.append(CustomStreamingDataset(config=config))
            self.train_dataset = CombinedStreamingDataset(train_datasets, iterate_over_all=True)

            # Validation
            ...

When using DDP with remote data, we get 1 iteration/second in terms of speed. After the 1st epoch, 15-16 steps run forward at 1 iteration/second and then training stalls for 3-5 minutes (no GPU utilization). Any ideas what the underlying issue could be?

@tchaton
Collaborator

tchaton commented Jan 23, 2025

Hey @schopra8 @rxqy Could you try out this branch: #456. I am trying to add a fallback mechanism to force download the chunk if it got deleted. Not ideal but should unblock DDP training.
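
To illustrate the idea of such a fallback (this is not the code in #456), here is a minimal sketch that waits for a chunk file with a timeout and then re-fetches it. The function name wait_for_chunk, the force_download callback, and the default timings are all assumptions made for the example.

```python
import os
import time
from typing import Callable


def wait_for_chunk(
    chunk_path: str,
    force_download: Callable[[str], None],
    timeout: float = 30.0,
    poll_interval: float = 0.1,
) -> bool:
    """Wait until `chunk_path` exists; if it does not appear within `timeout`
    seconds, call `force_download` to fetch it again.

    Returns True if the fallback download was needed."""
    deadline = time.monotonic() + timeout
    while not os.path.exists(chunk_path):
        if time.monotonic() >= deadline:
            # The chunk was probably deleted (or never arrived): fetch it again.
            force_download(chunk_path)
            return True
        time.sleep(poll_interval)
    return False
```

The timeout turns an indefinite hang into a bounded wait followed by a retry, which is why it can unblock training without addressing why the chunk disappeared.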

@JackUrb
Contributor

JackUrb commented Jan 27, 2025

@tchaton +1 to this. The issue we noted is that one of the dataloader worker threads ends up stuck at exactly the loop you've added a timeout to in #456. In our case the zstd.bin exists but the .bin does not, and it never seems to get unpacked (or it was unpacked, then removed). I imagine #456 will cover us too, though I agree it doesn't address the underlying issue.

For what it's worth, I had a hard time tracking this bug down last week, and found that it persists even when all of the StreamingDataLoaders and underlying StreamingDatasets are initialized with per-GPU unique cache_dirs, implying that the unzip fails to occur as expected within a single rank.

@dawood95

Training on 4 nodes, 8 GPUs each, with DDP. I still see this problem with #456 in. I currently have training running by:

  1. Not using compression (@tchaton I believe there needs to be a timeout loop there as well, at least temporarily)
  2. Using a really large cache size
  3. Using a high max_pre_download number
  4. Decreasing the number of workers (although I'm not sure this is needed)
  5. Using a per-training-worker cache directory
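
The per-worker cache directory and the cache settings above could be sketched as follows. This assumes the launcher exports RANK or LOCAL_RANK (torchrun does); the helper name per_rank_cache_dir and the values are placeholders. How the directory is handed to StreamingDataset depends on the LitData version, so that call is left out.

```python
import os


def per_rank_cache_dir(base_dir: str) -> str:
    """Return (and create) a cache directory unique to this training process."""
    # Prefer the global RANK, then LOCAL_RANK; fall back to 0 for single-process runs.
    rank = os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0"))
    path = os.path.join(base_dir, f"rank-{rank}")
    os.makedirs(path, exist_ok=True)
    return path


# Settings named in the list above; the values are placeholders, not recommendations.
streaming_kwargs = {
    "max_cache_size": "200GB",  # "a really large cache size"
    "max_pre_download": 8,      # "a high max_pre_download number"
}
```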

@tchaton
Collaborator

tchaton commented Feb 10, 2025

Hey @JackUrb. Thanks for the info. Yes, we need to spend more time finding the source of this bug.

I wonder if you could add more prints to see if you learn more. Happy to pair debug with you

@JackUrb
Contributor

JackUrb commented Feb 10, 2025

I've also got this running stably by introducing a count file for the shards. Running a 4x8 job at the moment that appears to be stable with:

  1. no compression
  2. 32GB cache size (shared across all workers/nodes), max predownload size of 5 per StreamingDataset.
  3. StreamingDataLoader prefetch count of 10 and 8 workers.
  4. A change to Downloader and PrepareChunksThread (reader) that stores a count of accesses per-chunkfile in a separate locked .cnt file, and only allows delete when this is 0.

At the moment, though, many chunks don't get cleaned up, as I end up with many counts that never go back to 0.

Once I get an 8x8 job stable under this setup, I'll try with compression as well, and if that looks good I'll open a PR for my count-lock change.
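
The count-file idea can be sketched roughly as below. This is not the actual change to Downloader or PrepareChunksThread; it assumes a POSIX system (fcntl.flock), and the function names are hypothetical. Only the .cnt suffix comes from the description above.

```python
import contextlib
import fcntl  # POSIX only
import os


@contextlib.contextmanager
def _locked_count(chunk_path: str):
    """Open `<chunk>.cnt` under an exclusive lock and yield (file, count)."""
    fd = os.open(chunk_path + ".cnt", os.O_RDWR | os.O_CREAT, 0o644)
    with os.fdopen(fd, "r+") as f:
        fcntl.flock(f, fcntl.LOCK_EX)  # released when the file is closed
        raw = f.read().strip()
        yield f, (int(raw) if raw else 0)


def _add(chunk_path: str, delta: int) -> int:
    """Add `delta` to the stored count (never below zero) and return it."""
    with _locked_count(chunk_path) as (f, count):
        count = max(0, count + delta)
        f.seek(0)
        f.truncate()
        f.write(str(count))
        return count


def acquire(chunk_path: str) -> int:
    """Register one more reader of the chunk."""
    return _add(chunk_path, +1)


def release(chunk_path: str) -> int:
    """Unregister a reader of the chunk."""
    return _add(chunk_path, -1)


def delete_if_unused(chunk_path: str) -> bool:
    """Remove the chunk while holding the lock, and only if the count is 0."""
    with _locked_count(chunk_path) as (_, count):
        if count > 0:
            return False
        with contextlib.suppress(FileNotFoundError):
            os.remove(chunk_path)
        return True
```

In a scheme like this, a reader that exits without calling release leaves the count above zero, which is one way counts could end up never returning to 0.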

@tchaton
Collaborator

tchaton commented Feb 11, 2025

Hey @JackUrb.

Thanks, great to hear! Feel free to open a draft PR already, so I can have a look and maybe investigate on my end too.

Best regards,
Thomas.

@dawood95

I eventually ran into the issue again; it manifests as a slowdown, presumably from constantly waiting for downloads.

@tchaton
Collaborator

tchaton commented Feb 13, 2025

Hey @dawood95 @rxqy @schopra8 Can you try this branch: #468. We think this will fix it.
