Skip to content

Commit

Permalink
New Profiler callback.
Browse files Browse the repository at this point in the history
  • Loading branch information
frobnitzem committed Dec 11, 2024
1 parent 1c2295c commit b042ac3
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
2 changes: 2 additions & 0 deletions nequip/train/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .nemo_ema import NeMoExponentialMovingAverage
from .write_xyz import TestTimeXYZFileWriter
from .wandb_watch import WandbWatch
from .profiler import Profiler

__all__ = [
SoftAdapt,
Expand All @@ -12,4 +13,5 @@
NeMoExponentialMovingAverage,
TestTimeXYZFileWriter,
WandbWatch,
Profiler,
]
45 changes: 45 additions & 0 deletions nequip/train/callbacks/profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch
import lightning
from lightning.pytorch.callbacks import Callback
from nequip.data import AtomicDataDict
from nequip.train import NequIPLightningModule

class Profiler(Callback):
"""Proxy class for `TensorBoard Profiler <https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html>`_.
Example usage in config:
::
trainer:
...
callbacks:
- _target_: nequip.train.callbacks.Profiler
trace_output: "./proflog"
Args:
trace_output (str): directory where profile data is stored
"""

def __init__(self, trace_output='proflog'):
super().__init__()
self.prof = torch.profiler.profile(
schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_output),
record_shapes=True,
profile_memory=True,
with_stack=True
)

def on_train_start(self, trainer, pl_module):
self.prof.start()
def on_train_end(self, trainer, pl_module):
self.prof.stop()
def on_train_batch_start(
self,
trainer: lightning.Trainer,
pl_module: NequIPLightningModule,
batch: AtomicDataDict.Type,
batch_idx: int,
):
""""""
self.prof.step()

0 comments on commit b042ac3

Please sign in to comment.