Merge branch 'mblaz/pyt-dist-mr-draft' into 'main'
Add PyT Distributed checkpoint format

See merge request ADLR/megatron-lm!1064
jaredcasper committed Mar 13, 2024
2 parents f0f8150 + baa76c7 commit 73ce965
Showing 8 changed files with 610 additions and 62 deletions.
4 changes: 2 additions & 2 deletions megatron/arguments.py
@@ -1109,8 +1109,8 @@ def _add_checkpointing_args(parser):
help='Determine if the checkpoint format is in legacy or distributed format.'
' If False, expects distributed checkpoint iff args.use_dist_ckpt.'
' Might slow down loading a bit (double rank0 ckpt load).')
group.add_argument('--dist-ckpt-format', type=str, default='zarr',
choices=['zarr'],
group.add_argument('--dist-ckpt-format', type=str, default='torch_dist',
choices=['zarr', 'torch_dist'],
help='Distributed checkpoint format to use.')

return parser
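
For context, the new default means a vanilla launch now writes checkpoints in the PyTorch Distributed format unless `--dist-ckpt-format zarr` is requested explicitly. A minimal argparse sketch of just this flag (the surrounding Megatron parser setup is assumed):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='checkpointing')
group.add_argument('--dist-ckpt-format', type=str, default='torch_dist',
                   choices=['zarr', 'torch_dist'],
                   help='Distributed checkpoint format to use.')

args = parser.parse_args([])                               # no flag given
assert args.dist_ckpt_format == 'torch_dist'               # new default
args = parser.parse_args(['--dist-ckpt-format', 'zarr'])   # previous behaviour, now opt-in
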
18 changes: 15 additions & 3 deletions megatron/core/dist_checkpointing/mapping.py
@@ -7,6 +7,7 @@
"""

import logging
from abc import ABC
from dataclasses import dataclass, replace
from itertools import chain
from typing import Any, Callable, Dict, Optional, Tuple, Union
@@ -27,8 +28,14 @@
ReplicaId = Union[int, Tuple[int, ...]]


class ShardedBase(ABC):
key: str
data: object
replica_id: ReplicaId


@dataclass
class ShardedTensor:
class ShardedTensor(ShardedBase):
"""Represents a mapping between a local tensor and a global tensor.
Global tensor is assumed to consist of many local tensors distributed
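
The new ShardedBase marker class gives ShardedTensor, ShardedObject and ShardedTensorFactory a common ancestor, so serialization code can select "everything sharded" with a single isinstance check instead of chaining per-type extractions. A sketch of what a helper like the extract_sharded_base imported in serialization.py below could look like, assuming the extract_matching_values utility from dict_utils that this code already uses (the real helper lives in megatron/core/dist_checkpointing/utils.py):

from megatron.core.dist_checkpointing.dict_utils import extract_matching_values
from megatron.core.dist_checkpointing.mapping import ShardedBase, ShardedStateDict

def extract_sharded_base(sharded_state_dict: ShardedStateDict):
    # Split the (possibly nested) state dict into (sharded_part, rest):
    # every leaf that subclasses ShardedBase goes into the first dict,
    # plain objects destined for common.pt stay in the second.
    return extract_matching_values(
        sharded_state_dict, lambda v: isinstance(v, ShardedBase)
    )
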
@@ -173,6 +180,11 @@ def from_rank_offsets(
allow_shape_mismatch,
)

def init_data(self, device: torch.device, init_fn=torch.empty):
if self.data is not None:
return
self.data = init_fn(self.local_shape, dtype=self.dtype, device=device)

def __str__(self):
return f'{self.__class__.__name__}(key=\'{self.key}\')'
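
init_data lets a ShardedTensor start out as pure metadata and have its backing tensor allocated lazily, e.g. directly on the target device right before a strategy reads data into it. A hedged usage sketch; without_data(), assumed here to return a metadata-only copy, is the same kind of payload stripping the validation code uses:

import torch

from megatron.core.dist_checkpointing.mapping import ShardedTensor

# Describe a 4x8 local shard (single rank, no offsets), then drop the payload
# to simulate the metadata-only placeholders a load path works with.
sh_ten = ShardedTensor.from_rank_offsets('decoder.weight', torch.zeros(4, 8))
sh_ten = sh_ten.without_data()                     # .data is None, metadata kept

# Materialise the backing tensor; a second call would be a no-op because
# init_data returns early once .data is populated.
sh_ten.init_data(torch.device('cpu'), init_fn=torch.empty)
assert sh_ten.data.shape == sh_ten.local_shape
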

@@ -214,7 +226,7 @@ def unwrap(self):


@dataclass
class ShardedObject:
class ShardedObject(ShardedBase):
"""Represents a mapping between a local object and a global object.
Global object is assumed to consist of many local objects distributed
@@ -250,7 +262,7 @@ def __str__(self):


@dataclass
class ShardedTensorFactory:
class ShardedTensorFactory(ShardedBase):
""" Allows to apply transformations to tensors before/after serialization.
The essence of those transformations is that they can be applied to
98 changes: 66 additions & 32 deletions megatron/core/dist_checkpointing/serialization.py
@@ -45,7 +45,12 @@
StrategyAction,
get_default_strategy,
)
from .utils import extract_sharded_tensors, extract_sharded_tensors_or_nonpersistent
from .utils import (
extract_nonpersistent,
extract_sharded_base,
extract_sharded_tensors,
extract_sharded_tensors_or_nonpersistent,
)

COMMON_STATE_FNAME = 'common.pt'

@@ -61,6 +66,17 @@ def load(
) -> StateDict:
"""Loading entrypoint.
In the steps below, the following verbs refer to corresponding objects:
- load = load from checkpoint
- extract = extract from sharded_state_dict
- add = add to the final state dict
Steps:
1. Load common state dict and form the base of the result state dict
2. Apply factories to sharded_state_dict
3. Extract LocalNonPersistentObject and add
4. (optional) Extract ShardedObjects, load and add
5. Extract ShardedBase, load, apply factory merges and add
Arguments:
sharded_state_dict (ShardedStateDict): state dict of the existing model
populated with ShardedTensors. Used as a mapping to determine which
@@ -81,20 +97,27 @@
if not sharded_state_dict:
return common_state_dict

sharded_objects, sharded_state_dict = load_sharded_objects(sharded_state_dict, checkpoint_dir)
merge(common_state_dict, sharded_objects)

sh_ten_factories, _ = extract_matching_values(
sharded_state_dict,
lambda x: isinstance(x, ShardedTensorFactory),
return_lists_as_dicts=True,
)
apply_factories(sharded_state_dict)
sharded_state_dict, _ = extract_sharded_tensors_or_nonpersistent(sharded_state_dict)
sharded_state_dict, nonpersistent_state_dict = extract_sharded_tensors(sharded_state_dict)

# Non-persistent objects
nonpersistent_state_dict, sharded_state_dict = extract_nonpersistent(sharded_state_dict)
dict_list_map_inplace(lambda o: o.unwrap(), nonpersistent_state_dict)
merge(common_state_dict, nonpersistent_state_dict)

# Sharded base
if not sharded_strategy.can_handle_sharded_objects:
# TODO: implement as part of the common strategy
sharded_objects, sharded_state_dict = load_sharded_objects(
sharded_state_dict, checkpoint_dir
)
merge(common_state_dict, sharded_objects)
sharded_state_dict, _ = extract_sharded_base(sharded_state_dict)

if validate_access_integrity:
validate_sharding_integrity(nested_values(sharded_state_dict))
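
Caller-side, the load entrypoint stays a single call; the sharded backend is resolved from the checkpoint's saved config, so the same code can read both zarr and torch_dist checkpoints. A minimal sketch (the model object, its sharded_state_dict() method and the checkpoint path are assumptions here):

from megatron.core.dist_checkpointing.serialization import load

# sharded_state_dict: metadata-only ShardedTensor/ShardedObject placeholders
# describing which shards this rank wants to read.
sharded_state_dict = model.sharded_state_dict()
state_dict = load(sharded_state_dict, '/checkpoints/iter_0001000')

# The result is a plain state dict: common (rank-0) values merged with the
# loaded shards, factory merges already applied, ready for load_state_dict.
model.load_state_dict(state_dict)
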

@@ -228,14 +251,22 @@ def save(
sharded_strategy: Union[SaveShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[SaveCommonStrategy, Tuple[str, int], None] = None,
validate_access_integrity: bool = True,
):
) -> None:
"""Saving entrypoint.
Extracts ShardedTensors from the given state dict. Rank 0 saves the
"regular" part of the checkpoint to common torch file.
The ShardedTensors are saved according to a strategy specified by the
config.
Steps:
1. Apply factories
2. Extract and discard LocalNonPersistentObject
3. Extract all ShardedBase object
4. Save all other objects to common.pt
5. (optional) Extract and save ShardedObjects
6. Save all ShardedBase objects
Arguments:
sharded_state_dict (ShardedStateDict): state dict of the populated with
ShardedTensors. Used as a mapping to determine how local tensors
@@ -269,29 +300,33 @@
sharded_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, *sharded_strategy)

apply_factories(sharded_state_dict)
sharded_state_dict, state_dict = extract_sharded_tensors_or_nonpersistent(sharded_state_dict)
sharded_state_dict, _ = extract_sharded_tensors(sharded_state_dict)
sharded_tensors = list(nested_values(sharded_state_dict))
_, sharded_state_dict = extract_nonpersistent(sharded_state_dict)
sharded_state_dict, state_dict = extract_sharded_base(sharded_state_dict)
_save_common_dict(state_dict, checkpoint_dir, True)

if validate_access_integrity:
validate_sharding_integrity(sharded_tensors)
validate_sharding_integrity(list(nested_values(sharded_state_dict)))

_save_common_dict(state_dict, checkpoint_dir, True)
if not sharded_strategy.can_handle_sharded_objects:
# TODO: implement as part of the common strategy
sharded_state_dict = _extract_and_save_sharded_objects(
sharded_state_dict, checkpoint_dir, validate_access_integrity
)

sharded_strategy.save(sharded_tensors, checkpoint_dir)
save_config(
CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version), checkpoint_dir
)
sharded_strategy.save(sharded_state_dict, checkpoint_dir)
if torch.distributed.get_rank() == 0:
save_config(
CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version), checkpoint_dir
)
torch.distributed.barrier()


# TODO: implement it as common torch strategy
def _save_common_dict(
state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False
):
common_state_dict = _extract_and_save_sharded_objects(
state_dict, checkpoint_dir, validate_consistency
)
if torch.distributed.get_rank() == 0:
torch.save(common_state_dict, checkpoint_dir / COMMON_STATE_FNAME)
torch.save(state_dict, checkpoint_dir / COMMON_STATE_FNAME)
if validate_consistency:
# TODO: implement checking consistency with rank 0 common dict on other ranks
pass
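
The save counterpart, as a hedged sketch: the (backend, version) tuple matches the Tuple[str, int] form accepted by sharded_strategy above, with version 1 assumed for the new torch_dist backend; the directory is expected to be an existing, empty checkpoint folder, and every rank must take part since the call ends with a barrier.

from megatron.core.dist_checkpointing.serialization import save

# sharded_state_dict: same structure as for loading, but with .data holding
# this rank's local tensors (e.g. model.sharded_state_dict(), assumed here).
save(
    sharded_state_dict,
    '/checkpoints/iter_0002000',                 # hypothetical, existing empty dir
    sharded_strategy=('torch_dist', 1),          # backend name from --dist-ckpt-format
)
# On return the checkpoint contains common.pt, the per-shard files written by
# the strategy, and the config file recorded by rank 0.
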
@@ -308,8 +343,6 @@ def _extract_and_save_sharded_objects(
state_dict, lambda v: isinstance(v, ShardedObject)
)
sharded_objects = list(nested_values(sharded_objects))
if validate_consistency:
validate_objects_sharding_integrity(sharded_objects)
for sh_obj in sharded_objects:
if is_main_replica(sh_obj.replica_id):
save_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
@@ -346,7 +379,10 @@ def validate_sharding_integrity(sharded_tensors: Iterable[ShardedTensor]):
for sharding in rank_shardings:
key_shardings[sharding.key].append((rank, sharding))
for key, shardings in key_shardings.items():
_validate_sharding_for_key(shardings)
if isinstance(shardings[0][1], ShardedObject):
_validate_objects_for_key(shardings)
else:
_validate_sharding_for_key(shardings)


def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
@@ -438,19 +474,17 @@ def _validate_sharding_for_key_flattened(tensors_by_shard):
)


def validate_objects_sharding_integrity(sharded_objects: List[ShardedObject]):
def _validate_objects_for_key(sharded_objects: List[ShardedObject]):
""" Ensure uniqueness of saved objects. """
local_sh_objs = [sh_obj.without_data() for sh_obj in sharded_objects]
all_sh_objs = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(all_sh_objs, local_sh_objs)
if torch.distributed.get_rank() != 0:
return
unique_keys = [
sh_obj.unique_key
for sh_obj in chain.from_iterable(all_sh_objs)
if is_main_replica(sh_obj.replica_id)
sh_obj.unique_key for _, sh_obj in sharded_objects if is_main_replica(sh_obj.replica_id)
]
if len(unique_keys) != len(set(unique_keys)):
duplicates = {k: cnt for k, cnt in Counter(unique_keys).items() if cnt > 1}
logger.error(f'Duplicate ShardedObject keys and counts: {duplicates}')
raise CheckpointingException(f'Duplicate ShardedObject keys: {list(duplicates.keys())}')
expected_shard_num = np.prod(sharded_objects[0][1].global_shape)
if len(unique_keys) != expected_shard_num:
err_msg = f'Invalid access pattern: {expected_shard_num - len(unique_keys)} ShardedObject are missing.'
logger.error(f'{err_msg} Existing shards: {unique_keys}')
raise CheckpointingException(err_msg)
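
The rewritten object validation checks a simple counting invariant: across all ranks, the main-replica copies of one ShardedObject key must line up with its full global_shape grid. A small illustration with hypothetical values (the key naming scheme below is made up):

import numpy as np

# An object sharded over a (tp=2, pp=4) grid should yield exactly
# np.prod((2, 4)) == 8 distinct main-replica keys across all ranks.
global_shape = (2, 4)
unique_keys = {f'optimizer.exp_avg/shard_{tp}_{pp}'
               for tp in range(2) for pp in range(4)}
assert len(unique_keys) == np.prod(global_shape)   # none missing
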
17 changes: 0 additions & 17 deletions megatron/core/dist_checkpointing/strategies/__init__.py
@@ -1,20 +1,3 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Various loading and saving strategies """

import logging

logger = logging.getLogger(__name__)

try:
import tensorstore
import zarr

from .tensorstore import _import_trigger
from .zarr import _import_trigger
except ImportError:
# Only print warning on first rank.
import os

if int(os.getenv('RANK', '0')) == 0:
logger.warning('Zarr-based strategies will not be registered because of missing packages')
31 changes: 23 additions & 8 deletions megatron/core/dist_checkpointing/strategies/base.py
@@ -23,18 +23,23 @@ class StrategyAction(Enum):

def get_default_strategy(action: StrategyAction, backend: str, version: int):
""" Retrieves a default strategy for a given action, backend and version. """
try:
if backend == 'zarr':
error_hint = ' Please install `zarr` and `tensorstore<=0.1.45` packages'
from .tensorstore import _import_trigger
from .zarr import _import_trigger
elif backend == 'torch_dist':
error_hint = ' Please use PyTorch version >=2.1'
from .torch import _import_trigger
except ImportError as e:
raise CheckpointingException(
f'Cannot import a default strategy for: {(action.value, backend, version)}. Error: {e}. Hint: {error_hint}'
) from e
try:
return default_strategies[action.value][(backend, version)]
except KeyError as e:
hint = ''
if backend == 'zarr':
try:
import tensorstore
import zarr
except ImportError:
hint = ' Please install `zarr` and `tensorstore<=0.1.45` packages'
raise CheckpointingException(
f'Cannot find a default strategy for: {(action.value, backend, version)}.{hint}'
f'Cannot find a default strategy for: {(action.value, backend, version)}'
) from e
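
Moving the imports inside get_default_strategy (and deleting the eager imports from strategies/__init__.py above) means a missing optional dependency only fails at the moment that backend is actually requested, with a targeted hint. A usage sketch; the version numbers are assumptions:

from megatron.core.dist_checkpointing.strategies.base import (
    StrategyAction,
    get_default_strategy,
)

# Resolves the new PyTorch Distributed strategy; raises CheckpointingException
# with the "Please use PyTorch version >=2.1" hint if .torch cannot be imported.
save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, 'torch_dist', 1)

# zarr stays available as an opt-in backend and only needs zarr/tensorstore
# installed when it is actually selected.
load_strategy = get_default_strategy(StrategyAction.LOAD_SHARDED, 'zarr', 1)
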


@@ -49,6 +54,11 @@ def check_backend_compatibility(self, loaded_version):
def check_version_compatibility(self, loaded_version):
raise NotImplementedError

@property
def can_handle_sharded_objects(self):
""" Returns whether or not this strategy can handle loading ShardedObjects. """
return False


class SaveStrategyBase(ABC):
""" Base class for a save strategy. Requires defining a backend type and version of the saved format. """
@@ -57,6 +67,11 @@ def __init__(self, backend: str, version: int):
self.backend = backend
self.version = version

@property
def can_handle_sharded_objects(self):
""" Returns whether or not this strategy can handle saving ShardedObjects. """
return False


class LoadCommonStrategy(LoadStrategyBase):
""" Load strategy for common (non-sharded) objects """