diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 79c9379ccf1d..bc9425a0b0cd 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1488,7 +1488,7 @@ def seed_worker(worker_id):
         )
 
     def get_checkpoint_io(self) -> CheckpointIO:
-        return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+        return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage)
 
     def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert (
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 96531a04fd69..12ee54f34106 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -404,7 +404,7 @@ def __init__(
 
     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
-            self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
+            self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.ep_group, self.moe_dp_group, self.zero_stage
         )
 
     def configure(
diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py
index 125c29370ad2..2bcd61db56db 100644
--- a/colossalai/booster/plugin/torch_fsdp_plugin.py
+++ b/colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -1,6 +1,5 @@
 import os
-from pathlib import Path
-from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Any
 
 import torch
 import torch.nn as nn
@@ -52,10 +51,40 @@ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: boo
 
     def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
         assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
         if checkpoint.endswith(".safetensors"):
-            checkpoint = load_flat(checkpoint, seperator="-")
+            checkpoint = load_flat(checkpoint, seperator=".")
         else:
             checkpoint = utils.load_state_dict(checkpoint)
+
         fsdp_model = optimizer.unwrap_model()
+        full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=False)
+        start_index = 0
+        id2name = {}
+        def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]:
+            nonlocal start_index
+            start_num = len(id2name)
+            id2name.update(
+                {
+                    i: p
+                    for i, p in enumerate(group["params"], start_index)
+                    if i not in id2name
+                }
+            )
+            end_num = len(id2name)
+            start_index += end_num - start_num
+
+        for g in full_optimizer_state["param_groups"]:
+            get_index_mapping(g)
+
+        new_state = {}
+        for key, value in checkpoint["state"].items():
+            new_state[id2name[int(key)]] = value
+        checkpoint["state"] = new_state
+        for g in checkpoint["param_groups"]:
+            new_group = []
+            for param_id in g["params"]:
+                new_group.append(id2name[param_id])
+            g["params"] = new_group
+
         sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
         optimizer.load_state_dict(sharded_osd)
@@ -70,18 +99,19 @@ def save_unsharded_model(
         cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
         with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
             full_model_state = model.state_dict()
-        if use_async:
-            from colossalai.utils.safetensors import save
-
-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state)
-            for k, v in full_model_state.items():
-                self.pinned_state_dicts[id(model)][k].copy_(v)
-                full_model_state[k] = self.pinned_state_dicts[id(model)][k]
-            writer = save(checkpoint, full_model_state)
-            self.async_writers.append(writer)
-        else:
-            utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
+        if self.coordinator.is_master():
+            if use_async:
+                from colossalai.utils.safetensors import save
+
+                if id(model) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state)
+                for k, v in full_model_state.items():
+                    self.pinned_state_dicts[id(model)][k].copy_(v)
+                    full_model_state[k] = self.pinned_state_dicts[id(model)][k]
+                writer = save(checkpoint, full_model_state)
+                self.async_writers.append(writer)
+            else:
+                utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
 
     def save_unsharded_optimizer(
         self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
@@ -91,20 +121,48 @@ def save_unsharded_optimizer(
         """
         assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
         fsdp_model = optimizer.unwrap_model()
+
         full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
-        if use_async:
-            from colossalai.utils.safetensors import _flatten_optim_state_dict, save
-
-            flatten_state_dict, metadata = _flatten_optim_state_dict(full_optimizer_state, seperator="-")
-            if id(optimizer) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
-            for k, v in flatten_state_dict.items():
-                self.pinned_state_dicts[id(optimizer)][k].copy_(v)
-                flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
-            writer = save(checkpoint, state_dict=flatten_state_dict, metadata=metadata)
-            self.async_writers.append(writer)
-        else:
-            utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
+
+        if self.coordinator.is_master():
+
+            # Save order indices instead of Tensors
+            name2id: Dict[str, int] = {}
+            start_index = 0
+
+            def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
+                nonlocal start_index
+                packed = {k: v for k, v in group.items() if k != "params"}
+                name2id.update(
+                    {
+                        p: i
+                        for i, p in enumerate(group["params"], start_index)
+                        if p not in name2id
+                    }
+                )
+                packed["params"] = [name2id[p] for p in group["params"]]
+                start_index += len(packed["params"])
+                return packed
+
+            param_groups = [pack_group(g) for g in full_optimizer_state["param_groups"]]
+            full_optimizer_state["param_groups"] = param_groups
+            new_state = {}
+            for key, value in full_optimizer_state["state"].items():
+                new_state[name2id[key]] = value
+            full_optimizer_state["state"] = new_state
+
+            if use_async:
+                from colossalai.utils.safetensors import _flatten_optim_state_dict, save
+                flatten_state_dict, metadata = _flatten_optim_state_dict(full_optimizer_state, seperator=".")
+                if id(optimizer) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
+                for k, v in flatten_state_dict.items():
+                    self.pinned_state_dicts[id(optimizer)][k].copy_(v)
+                    flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
+                writer = save(checkpoint, state_dict=flatten_state_dict, metadata=metadata)
+                self.async_writers.append(writer)
+            else:
+                utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
 
     def save_sharded_model(
         self,
@@ -150,7 +208,7 @@ def save_sharded_model(
                 checkpoint=checkpoint_path,
                 index_file=index_file,
                 base_filename=weights_name,
-                is_master=True,
+                is_master=self.coordinator.is_master(),
             )
             self.async_writers.extend(writers)
         else:
@@ -234,6 +292,32 @@ def save_sharded_optimizer(
             )
 
         if self.coordinator.is_master():
+
+            # Save order indices instead of Tensors
+            name2id: Dict[str, int] = {}
+            start_index = 0
+
+            def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
+                nonlocal start_index
+                packed = {k: v for k, v in group.items() if k != "params"}
+                name2id.update(
+                    {
+                        p: i
+                        for i, p in enumerate(group["params"], start_index)
+                        if p not in name2id
+                    }
+                )
+                packed["params"] = [name2id[p] for p in group["params"]]
+                start_index += len(packed["params"])
+                return packed
+
+            param_groups = [pack_group(g) for g in fsdp_optim_state["param_groups"]]
+            fsdp_optim_state["param_groups"] = param_groups
+            new_state = {}
+            for key, value in fsdp_optim_state["state"].items():
+                new_state[name2id[key]] = value
+            fsdp_optim_state["state"] = new_state
+
             # Preparing file paths and index file.
             states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(
                 prefix, use_safetensors=use_async
@@ -261,7 +345,7 @@ def save_sharded_optimizer(
                     checkpoint=checkpoint,
                     index_file=index_file,
                     base_filename=states_name,
-                    is_master=True,
+                    is_master=self.coordinator.is_master(),
                     state_preprocess=True,
                 )
                 self.async_writers.extend(writers)
             else:
@@ -306,13 +390,43 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, siz
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
         for shard_file in checkpoint_files:
             if shard_file.endswith(".safetensors"):
-                state_dict_shard = load_flat(shard_file, seperator="-")
+                state_dict_shard = load_flat(shard_file, seperator=".")
             else:
                 state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
             fsdp_optim_state.update(state_dict_shard)
 
         fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)
+        fsdp_model = optimizer.unwrap_model()
+        full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model.unwrap(), optim=optimizer, rank0_only=False)
+        start_index = 0
+        id2name = {}
+        def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]:
+            nonlocal start_index
+            start_num = len(id2name)
+            id2name.update(
+                {
+                    i: p
+                    for i, p in enumerate(group["params"], start_index)
+                    if i not in id2name
+                }
+            )
+            end_num = len(id2name)
+            start_index += end_num - start_num
+
+        for g in full_optimizer_state["param_groups"]:
+            get_index_mapping(g)
+
+        new_state = {}
+        for key, value in fsdp_optim_dict["state"].items():
+            new_state[id2name[int(key)]] = value
+        fsdp_optim_dict["state"] = new_state
+        for g in fsdp_optim_dict["param_groups"]:
+            new_group = []
+            for param_id in g["params"]:
+                new_group.append(id2name[param_id])
+            g["params"] = new_group
+
 
         with FSDP.state_dict_type(optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT):
             fsdp_state = FSDP.optim_state_dict_to_load(
                 model=optimizer.unwrap_model().unwrap(), optim=optimizer, optim_state_dict=fsdp_optim_dict
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index d553294d0838..e7c203910e14 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -70,6 +70,7 @@ def __init__(
         dp_group: ProcessGroup,
         pp_group: ProcessGroup,
         tp_group: ProcessGroup,
+        sp_group: ProcessGroup,
         zero_stage: int,
         verbose: bool = True,
     ) -> None:
@@ -77,9 +78,11 @@ def __init__(
         self.global_dp_group = dp_group
         self.pp_group = pp_group
         self.tp_group = tp_group
+        self.sp_group = sp_group
        self.dp_rank = dist.get_rank(self.global_dp_group)
         self.tp_rank = dist.get_rank(self.tp_group)
         self.pp_rank = dist.get_rank(self.pp_group)
+        self.sp_rank = dist.get_rank(self.sp_group)
         self.global_dp_size = dist.get_world_size(dp_group)
         self.pp_size = dist.get_world_size(pp_group)
         self.tp_size = dist.get_world_size(tp_group)
@@ -490,7 +493,7 @@ def save_sharded_optimizer(
 
         # Then collect the sharded states along dp_group(if using zero)/tp_group.
         # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
-        control_saving = self.dp_rank == 0 and self.tp_rank == 0
+        control_saving = self.dp_rank == 0 and self.tp_rank == 0 and self.sp_rank == 0
 
         if use_async and control_saving:
             if id(optimizer) not in self.pinned_state_dicts:
@@ -560,8 +563,10 @@ def save_sharded_optimizer(
             Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
 
             # Manage filenames of sharded weights and index file for each pipeline stage.
-            states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
-            states_name = states_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
+            if not use_async:
+                states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
+            else:
+                states_name = states_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
             save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
             save_index_file = os.path.join("tmp_index_files", save_index_file)
 
diff --git a/colossalai/checkpoint_io/moe_checkpoint.py b/colossalai/checkpoint_io/moe_checkpoint.py
index 3b07856ca06c..244f5bc0b644 100644
--- a/colossalai/checkpoint_io/moe_checkpoint.py
+++ b/colossalai/checkpoint_io/moe_checkpoint.py
@@ -44,12 +44,13 @@ def __init__(
         global_dp_group: ProcessGroup,
         pp_group: ProcessGroup,
         tp_group: ProcessGroup,
+        sp_group: ProcessGroup,
         ep_group: ProcessGroup,
         moe_dp_group: ProcessGroup,
         zero_stage: int,
         verbose: bool = True,
     ) -> None:
-        super().__init__(global_dp_group, pp_group, tp_group, zero_stage, verbose)
+        super().__init__(global_dp_group, pp_group, tp_group, sp_group, zero_stage, verbose)
         self.global_dp_group = global_dp_group
         self.global_dp_rank = dist.get_rank(global_dp_group)
         self.global_dp_size = dist.get_world_size(global_dp_group)
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 8f1589831821..a434ac6bd57c 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -306,7 +306,7 @@ def async_save_state_dict_shards(
         checkpoint_file_path = os.path.join(checkpoint, shard_file)
 
         if state_preprocess:
-            state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator="-")
+            state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator=".")
         else:
             state_dict = shard
 
diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py
index 60ace5f57442..b412953dd76e 100644
--- a/colossalai/utils/safetensors.py
+++ b/colossalai/utils/safetensors.py
@@ -59,7 +59,7 @@ def _cast_to_object(tensor: torch.Tensor):
     return _tensor_to_object(tensor, tensor.numel() * tensor.element_size())
 
 
-def _flatten_optim_state_dict(state_dict: dict, seperator: str = "-") -> Tuple[dict, Optional[dict]]:
+def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[dict, Optional[dict]]:
     flat_dict = {}
     non_tensor_keys = []
     if "state" in state_dict:
@@ -196,7 +196,7 @@ def move_and_save(
     return f_writer
 
 
-def load_flat(checkpoint_path, seperator: str = "-"):
+def load_flat(checkpoint_path, seperator: str = "."):
     with safe_open(checkpoint_path, framework="pt") as f:
         metadata = f.metadata()
         state_dict_load = load_file(checkpoint_path)
diff --git a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
index b362875d054d..25d901538064 100644
--- a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
@@ -114,8 +114,8 @@ def run_model():
 
         run_model()
 
-        booster.save_model(fsdp_model, model_ckpt_path, shard=True, use_async=False)
-        booster.save_optimizer(optimizer, optim_ckpt_path, shard=True, use_async=True)
+        booster.save_model(fsdp_model, model_ckpt_path, shard=True, use_async=use_async)
+        booster.save_optimizer(optimizer, optim_ckpt_path, shard=True, use_async=use_async)
 
         booster.checkpoint_io._sync_d2h()
         booster.checkpoint_io._sync_io()
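
Note on the key packing repeated in `torch_fsdp_plugin.py` above: `FSDP.full_optim_state_dict` keys optimizer states by parameter name, and the `pack_group` helpers rewrite those names into contiguous integer ids before the state is flattened and saved, while the `get_index_mapping` helpers invert the mapping before `scatter_full_optim_state_dict` / `optim_state_dict_to_load` on the load path. The snippet below is a minimal standalone sketch of that round trip, not code from the patch; `pack_state_by_index`, `unpack_state_by_name`, and the toy state dict are illustrative names only.

```python
# Standalone sketch (assumed names, not from the patch): round-trip an optimizer
# state dict between parameter-name keys and contiguous integer-index keys.
from typing import Any, Dict


def pack_state_by_index(full_state: Dict[str, Any]) -> Dict[str, Any]:
    """Replace parameter-name keys with contiguous integer ids (save path)."""
    name2id: Dict[str, int] = {}
    start = 0
    packed_groups = []
    for group in full_state["param_groups"]:
        # Copy every group field except "params", which is rewritten to integer ids.
        packed = {k: v for k, v in group.items() if k != "params"}
        name2id.update({p: i for i, p in enumerate(group["params"], start) if p not in name2id})
        packed["params"] = [name2id[p] for p in group["params"]]
        start += len(packed["params"])
        packed_groups.append(packed)
    return {
        "state": {name2id[name]: s for name, s in full_state["state"].items()},
        "param_groups": packed_groups,
    }


def unpack_state_by_name(indexed_state: Dict[str, Any], reference: Dict[str, Any]) -> Dict[str, Any]:
    """Map integer ids back to parameter names using a name-keyed reference (load path)."""
    id2name: Dict[int, str] = {}
    start = 0
    for group in reference["param_groups"]:
        id2name.update({i: p for i, p in enumerate(group["params"], start) if i not in id2name})
        start += len(group["params"])
    return {
        "state": {id2name[int(i)]: s for i, s in indexed_state["state"].items()},
        "param_groups": [
            {**g, "params": [id2name[i] for i in g["params"]]} for g in indexed_state["param_groups"]
        ],
    }


if __name__ == "__main__":
    # Toy, name-keyed state dict shaped like FSDP.full_optim_state_dict output.
    full = {
        "state": {"layer.weight": {"step": 1}, "layer.bias": {"step": 1}},
        "param_groups": [{"lr": 1e-3, "params": ["layer.weight", "layer.bias"]}],
    }
    packed = pack_state_by_index(full)
    assert packed["param_groups"][0]["params"] == [0, 1]
    restored = unpack_state_by_name(packed, full)
    assert restored == full
```

The mapping only round-trips if the reference state dict enumerates parameters in the same order everywhere, which is presumably why the load paths in the patch rebuild it from `FSDP.full_optim_state_dict(..., rank0_only=False)` on every rank before scattering.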