diff --git a/dev-requirements.txt b/dev-requirements.txt index 8dbcac2ec..144f102e9 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -20,7 +20,8 @@ pyre-extensions pyre-check pytest pytest-cov -pytorch-lightning==1.5.10 +pytorch-lightning==2.3.1 +tensorboard==2.14.0 sagemaker>=2.149.0 torch-model-archiver>=0.4.2 torch==2.2.1 diff --git a/torchx/examples/apps/lightning/profiler.py b/torchx/examples/apps/lightning/profiler.py index 24f2a602e..f5ab0dee4 100644 --- a/torchx/examples/apps/lightning/profiler.py +++ b/torchx/examples/apps/lightning/profiler.py @@ -19,18 +19,19 @@ import time from typing import Dict -from pytorch_lightning.loggers.base import LightningLoggerBase -from pytorch_lightning.profiler.base import BaseProfiler +from pytorch_lightning.loggers.logger import Logger +from pytorch_lightning.profilers.profiler import Profiler -class SimpleLoggingProfiler(BaseProfiler): + +class SimpleLoggingProfiler(Profiler): """ This profiler records the duration of actions (in seconds) and reports the mean duration of each action to the specified logger. Reported metrics are in the format `duration_`. """ - def __init__(self, logger: LightningLoggerBase) -> None: + def __init__(self, logger: Logger) -> None: super().__init__() self.current_actions: Dict[str, float] = {}