From 2aaa06800a0d13e834333ed832893ab6585dc6b6 Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Mon, 14 Oct 2024 08:30:02 +0200 Subject: [PATCH] make metrics device configurable through env variable --- docs/user-guide/getting-started/how_to_use.md | 3 ++- src/eva/core/models/modules/module.py | 13 ++++--------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/docs/user-guide/getting-started/how_to_use.md b/docs/user-guide/getting-started/how_to_use.md index 2af3337e..a903e1d4 100644 --- a/docs/user-guide/getting-started/how_to_use.md +++ b/docs/user-guide/getting-started/how_to_use.md @@ -58,4 +58,5 @@ To customize runs, without the need of creating custom config-files, you can ove | `MONITOR_METRIC` | `str` | The metric to monitor for early stopping and final model checkpoint loading | | `MONITOR_METRIC_MODE` | `str` | "min" or "max", depending on the `MONITOR_METRIC` used | | `REPO_OR_DIR` | `str` | GitHub repo with format containing model implementation, e.g. "facebookresearch/dino:main" | -| `TQDM_REFRESH_RATE` | `str` | Determines at which rate (in number of batches) the progress bars get updated. Set it to 0 to disable the progress bar. | \ No newline at end of file +| `TQDM_REFRESH_RATE` | `str` | Determines at which rate (in number of batches) the progress bars get updated. Set it to 0 to disable the progress bar. | +| `METRICS_DEVICE` | `str` | Specifies the device on which to compute the metrics. If not set, the metrics are computed on the same device as the one used for training. 
| \ No newline at end of file diff --git a/src/eva/core/models/modules/module.py b/src/eva/core/models/modules/module.py index d1e2ab64..ba65f679 100644 --- a/src/eva/core/models/modules/module.py +++ b/src/eva/core/models/modules/module.py @@ -1,10 +1,10 @@ """Base model module.""" +import os from typing import Any, Mapping import lightning.pytorch as pl import torch -from lightning.pytorch.strategies.single_device import SingleDeviceStrategy from lightning.pytorch.utilities import memory from lightning.pytorch.utilities.types import STEP_OUTPUT from typing_extensions import override @@ -49,14 +49,9 @@ def default_postprocess(self) -> batch_postprocess.BatchPostProcess: @property def metrics_device(self) -> torch.device: - """Returns the device by which the metrics should be calculated. - - We allocate the metrics to CPU when operating on single device, as - it is much faster, but to GPU when employing multiple ones, as DDP - strategy requires the metrics to be allocated to the module's GPU. - """ - move_to_cpu = isinstance(self.trainer.strategy, SingleDeviceStrategy) - return torch.device("cpu") if move_to_cpu else self.device + """Returns the metrics device: `METRICS_DEVICE` if set and non-empty, else the module's device.""" + device = os.getenv("METRICS_DEVICE") + return torch.device(device) if device else self.device @override def on_fit_start(self) -> None: