Commit
Non-reentrant checkpointing hook fix (microsoft#5781)
This PR adds an extra condition so that backward-pass hooks are attached to leaf nodes only when PROFILE_TIME or SYNCHRONIZE is enabled, as otherwise these hooks are unnecessary. Hook code below:

```
def after_backward_hook(_nonuse_grads):
    """the hook registered to all leaf tensors"""
    nonlocal leaf_tensors, backward_visited_leaf_nodes
    backward_visited_leaf_nodes += 1

    if backward_visited_leaf_nodes == len(leaf_tensors):
        see_memory_usage("After backward checkpointing code after backward", force=False)

        if PROFILE_TIME:
            timers('backward').stop()
            timers.log(['backward'])
        if SYNCHRONIZE:
            get_accelerator().synchronize()
```

The `see_memory_usage` call never has any effect here, as `force` is hardcoded to `False`. Thus this hook does real work only when PROFILE_TIME or SYNCHRONIZE is True; otherwise it only adds unnecessary function calls per leaf tensor.

Co-authored-by: Heyang Qin <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
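For context, here is a minimal runnable sketch of the guarded registration this PR describes, assuming PyTorch's `Tensor.register_hook` as the attachment mechanism. The wrapper function `run_with_leaf_hooks` and the print stand-ins are hypothetical illustrations, not the actual DeepSpeed code:

```
import torch

# Flags mirroring the snippet above; with both off, no hooks are attached.
PROFILE_TIME = False
SYNCHRONIZE = False

def run_with_leaf_hooks(inputs):
    """Hypothetical demo of the guarded hook registration (not DeepSpeed code)."""
    leaf_tensors = [t for t in inputs if t.is_leaf and t.requires_grad]
    backward_visited_leaf_nodes = 0

    def after_backward_hook(_unused_grad):
        nonlocal backward_visited_leaf_nodes
        backward_visited_leaf_nodes += 1
        if backward_visited_leaf_nodes == len(leaf_tensors):
            if PROFILE_TIME:
                print("backward timer stopped")  # stand-in for timers('backward').stop()
            if SYNCHRONIZE:
                torch.cuda.synchronize()  # stand-in for get_accelerator().synchronize()

    # The fix: skip hook registration entirely when the hook would be a no-op.
    if PROFILE_TIME or SYNCHRONIZE:
        for leaf in leaf_tensors:
            leaf.register_hook(after_backward_hook)

    return sum(t.sum() for t in inputs)

# Usage: with both flags False, backward() runs without any extra hook calls.
x = torch.randn(4, requires_grad=True)
run_with_leaf_hooks([x]).backward()
```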