DDP mode stuck with multiple GPUs when calling MAP metric #626

Closed
tkupek opened this issue Nov 12, 2021 · 4 comments · Fixed by #624
Labels: bug / fix (Something isn't working), topic: Image

tkupek (Contributor) commented Nov 12, 2021

🐛 Bug

I am training a detection model on multiple GPUs in DDP mode and tracking the MAP metric.
In some iterations the validation gets stuck when calling the compute() method: all GPUs sit at 100% load and nothing happens for hours.

To Reproduce

I was able to reproduce the issue with the BoringModel.
⚠️ Make sure to run this with at least two GPUs to trigger the issue.

import torch
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import Dataset, DataLoader
from torchmetrics import MAP


class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


train = RandomDataset(32, 100)
train = DataLoader(train, batch_size=32)

val = RandomDataset(32, 100)
val = DataLoader(val, batch_size=32)

# mockups for MAP compatible data
mock_preds = [
    dict(
        boxes=torch.Tensor([[258.0, 41.0, 606.0, 285.0]]),
        scores=torch.Tensor([0.536]),
        labels=torch.IntTensor([0]),
    )
]
mock_target = [
    dict(
        boxes=torch.Tensor([[214.0, 41.0, 562.0, 285.0]]),
        labels=torch.IntTensor([0]),
    )
]


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.val_map = MAP(class_metrics=True)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        torch.stack([x["loss"] for x in outputs]).mean()

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)

        # ignore real outputs and add mockup preds to metric
        preds = []
        target = []
        for x in mock_preds:
            preds.append({
                'boxes': x['boxes'].to(self.device),
                'labels': x['labels'].to(self.device),
                'scores': x['scores'].to(self.device)
            })

        for x in mock_target:
            target.append({
                'boxes': x['boxes'].to(self.device),
                'labels': x['labels'].to(self.device)
            })

        self.val_map.update(preds=preds, target=target)
        return {"x": loss}

    def on_validation_epoch_start(self) -> None:
        self.val_map.reset()

    def on_validation_epoch_end(self) -> None:
        if self.trainer.global_step != 0 and self.trainer.is_global_zero:
            print(
                f"Running val metric on {len(self.val_map.groundtruth_boxes)} samples"
            )
            result = self.val_map.compute() # GPUs get stuck here
            print(result)

    def configure_optimizers(self):
        return [torch.optim.SGD(self.layer.parameters(), lr=0.1)]


model = BoringModel()
trainer = Trainer(
    max_epochs=1,
    strategy='ddp',
    gpus=2
)

trainer.fit(model, train, val)

Expected behavior

The metric gets calculated just as with a single GPU:
{'map': tensor([0.6000]), 'map_50': tensor([1.]), 'map_75': tensor([1.]), 'map_small': tensor([-1.]), 'map_medium': tensor([-1.]), 'map_large': tensor([0.6000]), 'mar_1': tensor([0.6000]), 'mar_10': tensor([0.6000]), 'mar_100': tensor([0.6000]), 'mar_small': tensor([-1.]), 'mar_medium': tensor([-1.]), 'mar_large': tensor([0.6000]), 'map_per_class': tensor([0.6000]), 'mar_100_per_class': tensor([0.6000])}

Environment

  • CUDA:
    - GPU:
      - NVIDIA GeForce RTX 2080 Ti
      - NVIDIA GeForce RTX 2080 Ti
      - NVIDIA GeForce RTX 2080 Ti
      - NVIDIA GeForce RTX 2080 Ti
    - available: True
    - version: 10.2
  • Packages:
    - numpy: 1.21.4
    - pyTorch_debug: False
    - pyTorch_version: 1.10.0+cu102
    - pytorch-lightning: 1.5.0
    - tqdm: 4.62.3
  • System:
    - OS: Linux
    - architecture: 64bit, ELF
    - processor: x86_64
    - python: 3.8.0
    - version: # 92-Ubuntu SMP Fri Feb 28 11:09:48 UTC 2020

Additional context

If it helps, I am available for a live debugging session.

SkafteNicki (Member) commented

The problem is that you are calling metric.compute from only one process.
Essentially, this does not work:

if self.trainer.is_global_zero:
    result = self.val_map.compute()

while this works:

result = self.val_map.compute()

The reason is that compute will call a dist.barrier at some point to sync all processes. If compute is only called on process 0, it will wait indefinitely for the other processes to reach the barrier (which they never will).
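
Putting the two snippets together, a minimal sketch of the safe pattern (reusing the hooks from the repro script above) is to call compute() on every rank and only act on the result on rank zero:

def on_validation_epoch_end(self) -> None:
    # compute() must run on every rank so that the underlying
    # collective ops (barrier / all_gather) are entered by all processes
    result = self.val_map.compute()

    # only rank zero prints or logs the result
    if self.trainer.is_global_zero:
        print(result)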

tkupek (Contributor, Author) commented Nov 15, 2021

Thanks for the hint. I fixed that obvious error, but I can still reproduce the hang when one of the GPUs receives an empty prediction tensor:

import random

import torch
import torch.distributed
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import Dataset, DataLoader
from torchmetrics import MAP

NUM_GPUS = 2
BATCH_SIZE = 32


class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


train = RandomDataset(32, 100 * NUM_GPUS * BATCH_SIZE)
train = DataLoader(train, batch_size=32)

val = RandomDataset(32, 100 * NUM_GPUS * BATCH_SIZE)
val = DataLoader(val, batch_size=BATCH_SIZE)

# mockups for MAP compatible data
mock_preds = [
    dict(
        boxes=torch.Tensor([]),
        scores=torch.Tensor([]),
        labels=torch.IntTensor([]),
    ),
    dict(
        boxes=torch.Tensor([[258.0, 41.0, 606.0, 285.0]]),
        scores=torch.Tensor([0.536]),
        labels=torch.IntTensor([0]),
    )
]
mock_target = [
    dict(
        boxes=torch.Tensor([[214.0, 41.0, 562.0, 285.0]]),
        labels=torch.IntTensor([0]),
    )
]


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.val_map = MAP(class_metrics=True, dist_sync_on_step=True)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        torch.stack([x["loss"] for x in outputs]).mean()

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)

        # ignore real outputs and add mockup preds to metric
        preds = []
        target = []

        for n in range(batch.size(0)):
            x = mock_preds[torch.distributed.get_rank()]
            preds.append({
                'boxes': x['boxes'].to(self.device),
                'labels': x['labels'].to(self.device),
                'scores': x['scores'].to(self.device)
            })

            x = mock_target[0]
            target.append({
                'boxes': x['boxes'].to(self.device),
                'labels': x['labels'].to(self.device)
            })

        self.val_map.update(preds=preds, target=target)
        return {"x": loss}

    def on_validation_epoch_start(self) -> None:
        self.val_map.reset()

    def on_validation_epoch_end(self) -> None:
        if self.trainer.global_step != 0:
            print(
                f"Running val metric on {len(self.val_map.groundtruth_boxes)} samples"
            )
            result = self.val_map.compute()  # GPUs get stuck here
            print(result)

    def configure_optimizers(self):
        return [torch.optim.SGD(self.layer.parameters(), lr=0.1)]


model = BoringModel()
trainer = Trainer(
    max_epochs=10,
    strategy='ddp',
    gpus=2
)

trainer.fit(model, train, val)

What is the problem here? Is it the empty tensor that makes DDP fail? Do we need to filter them out in the .update method of the MAP metric?
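
One plausible explanation for the hang (a sketch, not the exact torchmetrics internals): the MAP metric keeps its state as lists of variable-length tensors, and plain torch.distributed.all_gather assumes every rank contributes a tensor of the same shape. When one rank holds an empty prediction tensor, the shapes no longer line up and the collective can block or misbehave. The usual workaround is to exchange the lengths first and pad before gathering, roughly like the hypothetical helper below:

import torch
import torch.distributed as dist


def gather_uneven(t: torch.Tensor) -> list:
    """Gather 1-D tensors of different lengths from all ranks.

    Every rank must call this function; local lengths are exchanged
    first and tensors are padded to the maximum before the all_gather.
    """
    world_size = dist.get_world_size()

    # 1. share the local length with every rank
    local_len = torch.tensor([t.numel()], device=t.device)
    lens = [torch.zeros_like(local_len) for _ in range(world_size)]
    dist.all_gather(lens, local_len)
    max_len = int(torch.stack(lens).max())

    # 2. pad to the common maximum and gather
    padded = torch.zeros(max_len, dtype=t.dtype, device=t.device)
    padded[: t.numel()] = t
    gathered = [torch.zeros_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded)

    # 3. trim the padding back off on every rank
    return [g[: int(n)] for g, n in zip(gathered, lens)]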

@SkafteNicki SkafteNicki reopened this Nov 16, 2021
@SkafteNicki SkafteNicki transferred this issue from Lightning-AI/pytorch-lightning Nov 16, 2021
github-actions (bot) commented

Hi! Thanks for your contribution, great first issue!

shawnye1994 commented

(quoting @SkafteNicki's explanation above that metric.compute must be called from every process, not just process 0)

For anyone who wants to perform the metric computation only on process 0, you can create your Metric with sync_dist=False.
