diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py index 529931ca0df1..f955cf5ebcad 100644 --- a/deepspeed/runtime/activation_checkpointing/checkpointing.py +++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py @@ -965,8 +965,9 @@ def after_backward_hook(_nonuse_grads): with torch.autograd.graph.saved_tensors_hooks(checkpoint_pack, checkpoint_unpack): outputs = function(*inputs_cuda) - for leaf_tensor in leaf_tensors: - leaf_tensor.register_hook(after_backward_hook) + if PROFILE_TIME or SYNCHRONIZE: + for leaf_tensor in leaf_tensors: + leaf_tensor.register_hook(after_backward_hook) see_memory_usage("After running forward on the layer", force=False)