diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py index 2fc2d8bccb6..fb534a35132 100644 --- a/sota-implementations/decision_transformer/online_dt.py +++ b/sota-implementations/decision_transformer/online_dt.py @@ -136,6 +136,7 @@ def update(data): data = offline_buffer.sample() with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() loss_vals = update(data.to(model_device)) scheduler.step()