AUC does not support compute_on_cpu in DDP. #1088

Closed
hslee-lunit opened this issue Jun 15, 2022 · 2 comments
Labels: bug / fix (Something isn't working), help wanted (Extra attention is needed), v0.8.x

Comments

@hslee-lunit

🐛 Bug

To Reproduce


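AUC computation under DDP fails when compute_on_cpu=True.
I suspect the cause is that the states x and y are kept on the CPU, but compute() triggers an all_gather across processes, and the NCCL backend only accepts CUDA tensors (hence "Tensors must be CUDA and dense").
Accuracy works fine, but AUC does not. Note that compute_on_cpu only moves list states (such as AUC's x and y) to the CPU, so metrics whose states are plain tensors are unaffected.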
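A possible workaround, sketched under assumptions rather than taken from TorchMetrics: in 0.8.x the Metric base class appears to accept a dist_sync_fn keyword (the callable used to all_gather each state during sync, assumed here to be called as dist_sync_fn(tensor, group=...)). A wrapper could temporarily copy CPU states onto this rank's GPU before an NCCL gather and move the results back afterwards. The name nccl_safe_gather is hypothetical, and the wrapped gather_fn is assumed to behave like torchmetrics.utilities.distributed.gather_all_tensors, returning one tensor per rank.

```python
import torch
import torch.distributed as dist


def nccl_safe_gather(tensor, group=None, *, gather_fn):
    """Hypothetical dist_sync_fn wrapper: stage CPU tensors on the GPU for NCCL.

    `gather_fn(tensor, group=...)` is assumed to return a list with one tensor
    per rank, like torchmetrics.utilities.distributed.gather_all_tensors.
    """
    # Outside an initialized process group there is nothing to gather.
    if not (dist.is_available() and dist.is_initialized()):
        return [tensor]

    home_device = tensor.device
    # NCCL rejects CPU tensors ("Tensors must be CUDA and dense"),
    # so copy CPU states to the GPU owned by this rank first.
    if dist.get_backend(group) == "nccl" and home_device.type != "cuda":
        tensor = tensor.to(torch.device("cuda", torch.cuda.current_device()))

    gathered = gather_fn(tensor, group=group)
    # Move every rank's piece back to where the state lived (the CPU when
    # compute_on_cpu=True), so compute() still runs on the CPU.
    return [t.to(home_device) for t in gathered]
```

Under these assumptions it would be wired in as AUC(reorder=True, compute_on_cpu=True, dist_sync_fn=functools.partial(nccl_safe_gather, gather_fn=gather_all_tensors)). The trade-off is an extra host/device copy at sync time in exchange for keeping the accumulated states off the GPU between updates.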

Error message
  self.advance(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 268, in advance
  self._outputs = self.epoch_loop.run(self._data_fetcher)
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
  self.advance(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 208, in advance
  batch_output = self.batch_loop.run(batch, batch_idx)
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
  self.advance(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
  outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
  self.advance(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 203, in advance
  result = self._run_optimization(
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 256, in _run_optimization
  self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 369, in _optimizer_step
  self.trainer._call_lightning_module_hook(
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1593, in _call_lightning_module_hook
  output = fn(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1644, in optimizer_step
  optimizer.step(closure=optimizer_closure)
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 168, in step
  step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 278, in optimizer_step
  optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 193, in optimizer_step
  return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 155, in optimizer_step
  return optimizer.step(closure=closure, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/optim/lr_scheduler.py", line 65, in wrapper
  return wrapped(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/optim/optimizer.py", line 89, in wrapper
  return func(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
  return func(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/optim/adadelta.py", line 50, in step
  loss = closure()
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 140, in _wrap_closure
  closure_result = closure()
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 148, in __call__
  self._result = self.closure(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 134, in closure
  step_output = self._step_fn()
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 427, in _training_step
  training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1763, in _call_strategy_hook
  output = fn(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 341, in training_step
  return self.model(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
  result = self.forward(*input, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 705, in forward
  output = self.module(*inputs[0], **kwargs[0])
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
  result = self.forward(*input, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 82, in forward
  output = self.module.training_step(*inputs, **kwargs)
File "mnist_examples/image_classifier_5_lightning_datamodule.py", line 51, in training_step
  self.log("train_auc", self.train_auc.compute(), prog_bar=True)
File "/opt/conda/lib/python3.8/site-packages/torchmetrics/metric.py", line 435, in wrapped_func
  with self.sync_context(
File "/opt/conda/lib/python3.8/contextlib.py", line 113, in __enter__
  return next(self.gen)
File "/opt/conda/lib/python3.8/site-packages/torchmetrics/metric.py", line 406, in sync_context
  self.sync(
File "/opt/conda/lib/python3.8/site-packages/torchmetrics/metric.py", line 358, in sync
  self._sync_dist(dist_sync_fn, process_group=process_group)
File "/opt/conda/lib/python3.8/site-packages/torchmetrics/metric.py", line 287, in _sync_dist
  output_dict = apply_to_collection(
File "/opt/conda/lib/python3.8/site-packages/torchmetrics/utilities/data.py", line 184, in apply_to_collection
  return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()})
File "/opt/conda/lib/python3.8/site-packages/torchmetrics/utilities/data.py", line 184, in <dictcomp>
  return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()})
File "/opt/conda/lib/python3.8/site-packages/torchmetrics/utilities/data.py", line 190, in apply_to_collection
  return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data])
File "/opt/conda/lib/python3.8/site-packages/torchmetrics/utilities/data.py", line 190, in <listcomp>
  return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data])
File "/opt/conda/lib/python3.8/site-packages/torchmetrics/utilities/data.py", line 180, in apply_to_collection
  return function(data, *args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torchmetrics/utilities/distributed.py", line 131, in gather_all_tensors
  torch.distributed.all_gather(local_sizes, local_size, group=group)
File "/opt/conda/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1865, in all_gather
  work = group.allgather([tensor_list], [tensor])
RuntimeError: Tensors must be CUDA and dense

Code sample

To reproduce this issue, I modified the pytorch-lightning example mnist_examples/image_classifier_5_lightning_datamodule.py.

Lines I changed are marked with # ADDED!

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple MNIST image classifier example with LightningModule and LightningDataModule.

To run: python image_classifier_5_lightning_datamodule.py --trainer.max_epochs=50
"""
import torch
import torchvision.transforms as T
from torch.nn import functional as F
from torchmetrics import Accuracy, AUC

from pl_examples import cli_lightning_logo
from pl_examples.basic_examples.mnist_datamodule import MNIST
from pl_examples.basic_examples.mnist_examples.image_classifier_1_pytorch import Net
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.utilities.cli import LightningCLI


class ImageClassifier(LightningModule):
    def __init__(self, model=None, lr=1.0, gamma=0.7, batch_size=32):
        super().__init__()
        self.save_hyperparameters(ignore="model")
        self.model = model or Net()
        self.train_auc = AUC(reorder=True, compute_on_cpu=True)  # ADDED!
        self.test_acc = Accuracy()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y.long())

        # Dummy inputs on this rank's GPU; randint's upper bound is exclusive, so 2 yields 0/1 labels.
        temp_logits = torch.rand(10, dtype=torch.float32, device=self.device)  # ADDED!
        temp_y = torch.randint(0, 2, (10,), dtype=torch.float32, device=self.device)  # ADDED!
        self.train_auc(temp_logits, temp_y)  # ADDED!
        self.log("train_auc", self.train_auc.compute(), prog_bar=True)  # ADDED!
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y.long())
        self.test_acc(logits, y)
        self.log("test_acc", self.test_acc)
        self.log("test_loss", loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adadelta(self.model.parameters(), lr=self.hparams.lr)
        return [optimizer], [torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.hparams.gamma)]


class MNISTDataModule(LightningDataModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.save_hyperparameters()

    @property
    def transform(self):
        return T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])

    def prepare_data(self) -> None:
        MNIST("./data", download=True)

    def train_dataloader(self):
        train_dataset = MNIST("./data", train=True, download=False, transform=self.transform)
        return torch.utils.data.DataLoader(train_dataset, batch_size=self.hparams.batch_size)

    def test_dataloader(self):
        test_dataset = MNIST("./data", train=False, download=False, transform=self.transform)
        return torch.utils.data.DataLoader(test_dataset, batch_size=self.hparams.batch_size)


def cli_main():
    # The LightningCLI removes all the boilerplate associated with arguments parsing. This is purely optional.
    cli = LightningCLI(
        ImageClassifier, MNISTDataModule, seed_everything_default=42, save_config_overwrite=True, run=False
    )
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)


if __name__ == "__main__":
    cli_lightning_logo()
    cli_main()

Then run it with DDP on two GPUs:

python mnist_examples/image_classifier_5_lightning_datamodule.py --trainer.accelerator 'gpu' --trainer.devices 2 --trainer.strategy 'ddp'

This produces the error message shown above.

Expected behavior

AUC should be computed without error, just as Accuracy is.

Environment

  • TorchMetrics version (and how you installed TM, e.g. conda, pip, build from source): 0.8.2
  • Python & PyTorch Version (e.g., 1.0): Python v3.8.8 & PyTorch v1.8.1
  • Any other relevant information such as OS (e.g., Linux): Linux

Additional context

@hslee-lunit hslee-lunit added bug / fix Something isn't working help wanted Extra attention is needed labels Jun 15, 2022
@github-actions

Hi! thanks for your contribution!, great first issue!

@hslee-lunit hslee-lunit changed the title Some metrics do not support compute_on_cpu in DDP. AUC does not support compute_on_cpu in DDP. Jun 15, 2022
@Borda Borda added this to the v0.10 milestone Jul 27, 2022
@SkafteNicki
Member

As part of the classification refactor described in #1001 and implemented in #1189, it was decided that `auc` will no longer be available as a metric in itself; instead we only offer a functional implementation in `torchmetrics.utilities.compute`. The reason is that `auc` is not really a metric but a tool to calculate the area under any given curve.

I know this does not explain the error reported in this issue, but I am going to close it because the error concerns the modular implementation, which is being removed entirely.
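To illustrate the point that auc is simply the area under a curve, here is a minimal sketch of a trapezoidal area computation with torch.trapz. The function area_under_curve is hypothetical, not the torchmetrics.utilities.compute API; its reorder flag only mirrors the reorder=True argument used in the code sample above. In a DDP run, each rank's x and y would have to be gathered first (on the GPU when the NCCL backend is used), since this function only sees the tensors it is given.

```python
import torch


def area_under_curve(x, y, reorder=False):
    """Hypothetical helper: trapezoidal area under the curve y(x)."""
    if reorder:
        # Sort points by x so the trapezoids are accumulated left to right.
        x, order = torch.sort(x)
        y = y[order]
    # trapz sums (y[i] + y[i+1]) / 2 * (x[i+1] - x[i]) over consecutive points.
    return torch.trapz(y, x)


x = torch.tensor([1.0, 0.0, 0.5])
y = torch.tensor([1.0, 0.0, 0.5])
print(area_under_curve(x, y, reorder=True).item())  # 0.5
```

Without reorder=True, the same unsorted points give a signed area of -0.375, because some trapezoids are traversed with negative width.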

3 participants