Non-reentrant checkpointing hook fix (microsoft#5781)
This PR adds an extra condition so that backward-pass hooks are attached to leaf
nodes only if synchronization (SYNCHRONIZE) or profiling (PROFILE_TIME) is
enabled; otherwise these hooks are unnecessary. Hook code below:

```python
    def after_backward_hook(_nonuse_grads):
        """the hook registered to all leaf tensors"""
        nonlocal leaf_tensors, backward_visited_leaf_nodes
        backward_visited_leaf_nodes += 1

        if backward_visited_leaf_nodes == len(leaf_tensors):
            see_memory_usage("After backward checkpointing code after backward", force=False)

            if PROFILE_TIME:
                timers('backward').stop()
                timers.log(['backward'])
            if SYNCHRONIZE:
                get_accelerator().synchronize()
```

`see_memory_usage` is never effective here, as `force` is hardcoded to `False`. Thus
this hook does real work only when PROFILE_TIME or SYNCHRONIZE is True;
otherwise it just adds unnecessary function calls.
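
As an aside, a minimal standalone sketch of the pattern (not DeepSpeed code; the flag names and tensor setup below are illustrative assumptions) shows why gating the registration matters: `Tensor.register_hook` fires once per leaf during the backward pass, so when the hook body is effectively a no-op the registrations only add per-tensor Python call overhead.

```python
# Minimal sketch, not DeepSpeed code: flags and tensors here are illustrative assumptions.
import torch

PROFILE_TIME = False   # assumed stand-ins for the module-level flags
SYNCHRONIZE = False

leaf_tensors = [torch.randn(4, requires_grad=True) for _ in range(3)]
backward_visited_leaf_nodes = 0


def after_backward_hook(_unused_grad):
    # Fires once per leaf tensor when its gradient has been computed.
    global backward_visited_leaf_nodes
    backward_visited_leaf_nodes += 1
    if backward_visited_leaf_nodes == len(leaf_tensors):
        if PROFILE_TIME:
            print("stop backward timer here")
        if SYNCHRONIZE and torch.cuda.is_available():
            torch.cuda.synchronize()


# The gist of the fix: only pay for the per-leaf hook calls when they can do real work.
if PROFILE_TIME or SYNCHRONIZE:
    for t in leaf_tensors:
        t.register_hook(after_backward_hook)

loss = sum((t * 2).sum() for t in leaf_tensors)
loss.backward()
```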

Co-authored-by: Heyang Qin <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
3 people authored Aug 2, 2024
1 parent 029bb52 commit 3a2d526
Showing 1 changed file with 3 additions and 2 deletions.
deepspeed/runtime/activation_checkpointing/checkpointing.py
```diff
@@ -965,8 +965,9 @@ def after_backward_hook(_nonuse_grads):
 
     with torch.autograd.graph.saved_tensors_hooks(checkpoint_pack, checkpoint_unpack):
         outputs = function(*inputs_cuda)
-    for leaf_tensor in leaf_tensors:
-        leaf_tensor.register_hook(after_backward_hook)
+    if PROFILE_TIME or SYNCHRONIZE:
+        for leaf_tensor in leaf_tensors:
+            leaf_tensor.register_hook(after_backward_hook)
 
     see_memory_usage("After running forward on the layer", force=False)
```
