Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cuda profiler hook #98

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions paxml/decode_programs.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,13 @@ def _run_decode_loop(
'Finished decoding input batch %d for %s', step_num, self._name
)

if (
profiler is not None
and step_num - self._task.decode.profiler_capture_step ==
profiler._capture_num_steps
):
profiler.stop_capture_async()

if jax.process_index() == 0:
# Copy the tensor from device memory to ram, since accumulating such
# tensor on devices may cause HBM OOM, when
Expand Down
13 changes: 13 additions & 0 deletions paxml/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,15 @@
"""Expose functionalities for profiling code."""

from absl import logging
from ctypes import cdll

libcudart = cdll.LoadLibrary('libcudart.so')
def cudaProfilerStart():
libcudart.cudaProfilerStart()
def cudaProfilerStop():
libcudart.cudaProfilerStop()
def cudaDeviceSynchronize():
libcudart.cudaDeviceSynchronize()

class Profiler:
"""Dummy class to capture code profiles.
Expand Down Expand Up @@ -64,8 +72,13 @@ def capture_async(self) -> None:

The duration of the trace corresponds to step_duration_estimate_sec.
"""
cudaProfilerStart()
logging.info('Dummy profiler currently does not capture any trace.')

def stop_capture_async(self) -> None:
cudaDeviceSynchronize()
cudaProfilerStop()

def update_step_moving_mean(self, duration_sec: float):
"""Updates the step duration moving average with a step duration estimate.

Expand Down
8 changes: 8 additions & 0 deletions paxml/programs.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,14 @@ def run(self, state: TrainState, step: int) -> TrainProgramOutput:

if do_profile and step - self._initial_step < profiler_capture_step:
self._profiler.update_step_moving_mean(train_period.elapsed)

if (
do_profile
and step - self._initial_step ==
profiler_capture_step + self._profiler._capture_num_steps
):
self._profiler.stop_capture_async()

logging.log_first_n(
logging.INFO, '[PAX STATUS]: Writing summaries (attempt).', 5
)
Expand Down