From 68250d0a1653c7943271dcbf8a639a9eb38368d6 Mon Sep 17 00:00:00 2001 From: Ankit Gola Date: Tue, 5 Mar 2024 11:53:21 +0200 Subject: [PATCH] Remove env var from HPUProfiler --- src/lightning_habana/pytorch/profiler/profiler.py | 9 ++++----- tests/test_pytorch/test_profiler.py | 6 ++++++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/lightning_habana/pytorch/profiler/profiler.py b/src/lightning_habana/pytorch/profiler/profiler.py index debacd5f..b0536ea5 100644 --- a/src/lightning_habana/pytorch/profiler/profiler.py +++ b/src/lightning_habana/pytorch/profiler/profiler.py @@ -89,7 +89,10 @@ def __init__( record_module_names: bool = True, **profiler_kwargs: Any, ) -> None: - os.environ["HABANA_PROFILE"] = "1" + assert os.environ.get("HABANA_PROFILE", None) in ( + None, + "profile_api_light", + ), "`HABANA_PROFILE` should not be set when using `HPUProfiler`" super().__init__( dirpath=dirpath, filename=filename, @@ -159,7 +162,3 @@ def on_trace_ready(profiler: _PROFILER) -> None: def summary(self) -> str: return "Summary not supported for HPU Profiler" - - def teardown(self, stage: Optional[str]) -> None: - super().teardown(stage=stage) - os.environ.pop("HABANA_PROFILE", None) diff --git a/tests/test_pytorch/test_profiler.py b/tests/test_pytorch/test_profiler.py index 84c06ec8..b3cf6e02 100644 --- a/tests/test_pytorch/test_profiler.py +++ b/tests/test_pytorch/test_profiler.py @@ -280,3 +280,9 @@ def test_hpu_trace_event_kernel(tmpdir): raise Exception("Could not find event kernel in trace") for event_duration in event_duration_arr: assert event_duration >= 0 + + +def test_hpu_profiler_env(monkeypatch): + monkeypatch.setenv("HABANA_PROFILE", "1") + with pytest.raises(AssertionError, match="`HABANA_PROFILE` should not be set when using `HPUProfiler`"): + HPUProfiler()