Skip to content

Commit

Permalink
make metrics device configurable through env variable
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaenzig committed Oct 14, 2024
1 parent bdc072f commit 2aaa068
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 10 deletions.
3 changes: 2 additions & 1 deletion docs/user-guide/getting-started/how_to_use.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,5 @@ To customize runs, without the need to create custom config files, you can ove
| `MONITOR_METRIC` | `str` | The metric to monitor for early stopping and final model checkpoint loading |
| `MONITOR_METRIC_MODE` | `str` | "min" or "max", depending on the `MONITOR_METRIC` used |
| `REPO_OR_DIR` | `str` | GitHub repo with format containing model implementation, e.g. "facebookresearch/dino:main" |
| `TQDM_REFRESH_RATE` | `str` | Determines at which rate (in number of batches) the progress bars get updated. Set it to 0 to disable the progress bar. |
| `TQDM_REFRESH_RATE` | `str` | Determines at which rate (in number of batches) the progress bars get updated. Set it to 0 to disable the progress bar. |
| `METRICS_DEVICE` | `str` | Specifies the device on which to compute the metrics. If not set, the same device as used for training will be used. |
13 changes: 4 additions & 9 deletions src/eva/core/models/modules/module.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Base model module."""

import os
from typing import Any, Mapping

import lightning.pytorch as pl
import torch
from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
from lightning.pytorch.utilities import memory
from lightning.pytorch.utilities.types import STEP_OUTPUT
from typing_extensions import override
Expand Down Expand Up @@ -49,14 +49,9 @@ def default_postprocess(self) -> batch_postprocess.BatchPostProcess:

@property
def metrics_device(self) -> torch.device:
"""Returns the device by which the metrics should be calculated.
We allocate the metrics to CPU when operating on single device, as
it is much faster, but to GPU when employing multiple ones, as DDP
strategy requires the metrics to be allocated to the module's GPU.
"""
move_to_cpu = isinstance(self.trainer.strategy, SingleDeviceStrategy)
return torch.device("cpu") if move_to_cpu else self.device
"""Returns the device by which the metrics should be calculated."""
device = os.getenv("METRICS_DEVICE", None)
return self.device if device is None else torch.device(device)

@override
def on_fit_start(self) -> None:
Expand Down

0 comments on commit 2aaa068

Please sign in to comment.