diff --git a/torchsnapshot/manifest_ops.py b/torchsnapshot/manifest_ops.py
index 9e45849..138b037 100644
--- a/torchsnapshot/manifest_ops.py
+++ b/torchsnapshot/manifest_ops.py
@@ -126,8 +126,8 @@ def handle_sharded_tensor_elasticity(
     :class:`ShardedTensor` can be elastic in several ways:
 
     - A rank loads a portion of a sharded tensor different from what it saved
-    - A rank loads a sharded tensor that did not participate in saving
-    - A rank doesn't load a sharded tensor that participated in saving
+    - A rank loads a sharded tensor that it did not participate in saving
+    - A rank doesn't load a sharded tensor that it participated in saving
 
     The first scenario is taken care of by :func:`get_manifest_for_rank`, which
     makes all shards available to all instances of :class:`ShardedTensorEntry`.
@@ -143,7 +143,7 @@
 
     NOTE: this function only takes effect if all sharded tensors are at the
     root of the state dict. This means the elastic behavior is supported for
-    most model but not supported for most optimizers.
+    most models but not supported for most optimizers.
 
     Args:
         manifest: The local manifest for the rank.
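Note on the docstring's caveat: the elastic handling only applies when the sharded tensors are direct values at the root of the state dict, which is the usual layout for a model state dict but not for an optimizer state dict. The sketch below is illustrative only, not torchsnapshot code; plain tensors stand in for ShardedTensor instances (which would require an initialized process group), and the layouts shown are the standard PyTorch ones.

import torch

# Typical model state dict: tensors (ShardedTensor instances in practice)
# sit directly at the root of the dict, so the elastic handling described
# in the docstring can apply.
model_state_dict = {
    "embedding.weight": torch.zeros(8, 4),
    "linear.weight": torch.zeros(4, 4),
}

# Typical optimizer state dict: tensors are nested under
# "state" -> parameter id -> buffer name, i.e. not at the root of the
# state dict, so the elastic behavior described above does not apply.
optimizer_state_dict = {
    "state": {
        0: {"exp_avg": torch.zeros(8, 4), "exp_avg_sq": torch.zeros(8, 4)},
        1: {"exp_avg": torch.zeros(4, 4), "exp_avg_sq": torch.zeros(4, 4)},
    },
    "param_groups": [{"lr": 1e-3, "params": [0, 1]}],
}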