From c8a43167731d0d328ba584a0631328beac3f8dd8 Mon Sep 17 00:00:00 2001 From: Shicong Huang Date: Tue, 2 Jul 2024 14:48:16 -0700 Subject: [PATCH] Use latest pytorch-lightning version (#926) Previously, we had fixed the version to an earlier release, which became incompatible after pip updated to 24.1. So, we took this opportunity to update it to the latest version. --- dev-requirements.txt | 3 ++- torchx/examples/apps/lightning/profiler.py | 9 +++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 8dbcac2ec..144f102e9 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -20,7 +20,8 @@ pyre-extensions pyre-check pytest pytest-cov -pytorch-lightning==1.5.10 +pytorch-lightning==2.3.1 +tensorboard==2.14.0 sagemaker>=2.149.0 torch-model-archiver>=0.4.2 torch==2.2.1 diff --git a/torchx/examples/apps/lightning/profiler.py b/torchx/examples/apps/lightning/profiler.py index 24f2a602e..f5ab0dee4 100644 --- a/torchx/examples/apps/lightning/profiler.py +++ b/torchx/examples/apps/lightning/profiler.py @@ -19,18 +19,19 @@ import time from typing import Dict -from pytorch_lightning.loggers.base import LightningLoggerBase -from pytorch_lightning.profiler.base import BaseProfiler +from pytorch_lightning.loggers.logger import Logger +from pytorch_lightning.profilers.profiler import Profiler -class SimpleLoggingProfiler(BaseProfiler): + +class SimpleLoggingProfiler(Profiler): """ This profiler records the duration of actions (in seconds) and reports the mean duration of each action to the specified logger. Reported metrics are in the format `duration_`. """ - def __init__(self, logger: LightningLoggerBase) -> None: + def __init__(self, logger: Logger) -> None: super().__init__() self.current_actions: Dict[str, float] = {}