From 1239a28c0ef37f81da75e561206f07155f091575 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 10 Dec 2024 11:08:45 +0800 Subject: [PATCH 01/18] fix --- .../checkpoint_io/general_checkpoint_io.py | 4 +- .../hybrid_parallel_checkpoint_io.py | 234 ++++++++++++++---- colossalai/checkpoint_io/utils.py | 61 ++++- ...st_hybrid_parallel_plugin_checkpoint_io.py | 17 +- 4 files changed, 256 insertions(+), 60 deletions(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 54da168e54d0..a75a625e1920 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -11,7 +11,7 @@ from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile from .utils import ( - async_save_state_dict_shards, + async_move_save_state_dict_shards, create_pinned_state_dict, get_model_base_filenames, get_optimizer_base_filenames, @@ -186,7 +186,7 @@ def save_sharded_model( if use_async: pinned_state_dict = self.pinned_state_dicts.get(id(model), None) - total_size, new_pinned_state_dict, writers = async_save_state_dict_shards( + total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards( sharded_state_dict=state_dict_shard, checkpoint=checkpoint_path, index_file=index_file, diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index e0701a247b54..4c1defe27556 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -22,6 +22,7 @@ to_unpadded_tensor, ) from colossalai.utils import get_current_device, get_non_persistent_buffers_set +from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat, save from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile @@ -88,7 +89,11 @@ def __init__( @staticmethod def _model_sharder( - model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024 + model: nn.Module, + prefix: str = "", + keep_vars: bool = False, + size_per_shard: int = 1024, + pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None, ) -> Iterator[Tuple[OrderedDict, int]]: # An internel method that breaks state_dict of model into shards within limited size. 
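NOTE: the hunks that follow thread an optional pinned_state_dicts cache through
_model_sharder, so each gathered tensor is staged into page-locked CPU memory
that is allocated once and reused on every later save; pinned memory speeds up
the device-to-host copy and lets the async file writer consume the tensor
without another staging copy. A minimal sketch of the pattern, with
hypothetical names:

    import torch

    def stage_to_pinned(key: str, tensor: torch.Tensor, cache: dict) -> torch.Tensor:
        # Allocate the pinned buffer lazily, sized from the tensor being copied,
        # so repeated checkpoints reuse the same page-locked memory.
        if key not in cache:
            cache[key] = torch.empty_like(tensor, pin_memory=True, device="cpu")
        cache[key].copy_(tensor)  # device-to-host copy into pinned memory
        return cache[key]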
@@ -102,6 +107,13 @@ def _model_sharder( if is_padded_tensor(param): param = to_unpadded_tensor(param) param_ = gather_distributed_param(param, keep_vars=False) + if pinned_state_dicts is not None: + if (prefix + name) not in pinned_state_dicts: + pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu") + pinned_state_dicts[prefix + name].copy_(param_) + param_ = pinned_state_dicts[prefix + name] + else: + param_ = param_.cpu() block, block_size = state_dict_sharder.append_param(prefix + name, param_) if block is not None: yield block, block_size @@ -111,6 +123,13 @@ def _model_sharder( for name, buf in model.named_buffers(): if buf is not None and name not in non_persist_buffers_set: buffer = buf if keep_vars else buf.detach() + if pinned_state_dicts is not None: + if (prefix + name) not in pinned_state_dicts: + pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu") + pinned_state_dicts[prefix + name].copy_(buffer) + buffer = pinned_state_dicts[prefix + name] + else: + buffer = buffer.cpu() block, block_size = state_dict_sharder.append_param(prefix + name, buffer) if block is not None: yield block, block_size @@ -122,6 +141,13 @@ def _model_sharder( is not torch.nn.Module.get_extra_state ): extra_state = model.get_extra_state() + if pinned_state_dicts is not None: + if extra_state_key not in pinned_state_dicts: + pinned_state_dicts[extra_state_key] = torch.empty_like(param_, pin_memory=True, device="cpu") + pinned_state_dicts[extra_state_key].copy_(extra_state) + extra_state = pinned_state_dicts[extra_state_key] + else: + extra_state = extra_state.cpu() block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) if block is not None: yield block, block_size @@ -136,6 +162,7 @@ def _optimizer_sharder( dp_group: ProcessGroup, tp_group: ProcessGroup, size_per_shard: int = 1024, + pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None, ): # An internel method that breaks state_dict of optimizer into shards within limited size. @@ -153,6 +180,9 @@ def _optimizer_sharder( working_param = param param_id = param_info["param2id"][id(working_param)] + if pinned_state_dicts is not None: + if param_id not in pinned_state_dicts: + pinned_state_dicts[param_id] = {} original_shape = param_info["param2shape"][id(working_param)] state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state( state, @@ -162,6 +192,7 @@ def _optimizer_sharder( tp_group=tp_group, use_zero=use_zero, inplace=False, + pinned_state_dicts=pinned_state_dicts[param_id] if pinned_state_dicts is not None else None, ) block, block_size = state_dict_sharder.append_optim_state(param_id, state_) @@ -216,15 +247,31 @@ def save_sharded_model( # Then collect the sharded parameters & buffers along tp_group. # Only devices with tp_rank == 0 are responsible for model saving. 
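NOTE: in the buffer and extra-state hunks above, the pinned buffer is allocated
with torch.empty_like(param_, ...), i.e. sized from the last parameter seen
rather than from the tensor actually being staged; the allocation should use
the staged tensor itself, as in the sketch after the _model_sharder signature
hunk. For optimizer states the cache gains one more level, one sub-dict per
parameter id, which gather_from_sharded_optimizer_state then fills key by key.
A hedged sketch of that shape (hypothetical helper name):

    from typing import Dict
    import torch

    # pinned maps param_id -> {"exp_avg": tensor, "exp_avg_sq": tensor, ...}
    def stage_optim_state(
        param_id: int,
        state: Dict[str, torch.Tensor],
        pinned: Dict[int, Dict[str, torch.Tensor]],
    ) -> Dict[str, torch.Tensor]:
        slot = pinned.setdefault(param_id, {})
        staged = {}
        for k, v in state.items():
            if k not in slot:
                slot[k] = torch.empty_like(v, pin_memory=True, device="cpu")
            slot[k].copy_(v)
            staged[k] = slot[k]
        return staged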
- state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard) + control_saving = self.tp_rank == 0 + if control_saving and use_async: + if id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = {} + pinned_state_dicts = self.pinned_state_dicts[id(model)] + else: + pinned_state_dicts = None + state_dict_shard = HybridParallelCheckpointIO._model_sharder( + model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts + ) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint) - control_saving = self.tp_rank == 0 if self.pp_size == 1: # When pipeline is not used, save the model shards as in general checkpointIO if use_async: - super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async) + total_size, writers = async_save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + state_preprocess=False, + ) + self.async_writers.extend(writers) else: total_size = save_state_dict_shards( sharded_state_dict=state_dict_shard, @@ -234,16 +281,16 @@ def save_sharded_model( is_master=control_saving, use_safetensors=use_safetensors, ) - if control_saving: - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - save_config_file(model, checkpoint) - if self.verbose and self.coordinator.is_master(): - logging.info( - f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." - ) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) else: # When pipeline is used, each stage produces its own shard files and index files. @@ -259,24 +306,25 @@ def save_sharded_model( save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") save_index_file = os.path.join("tmp_index_files", save_index_file) if use_async: - total_size, returned_state_dict, writers = async_save_state_dict_shards( + total_size, writers = async_save_state_dict_shards( sharded_state_dict=state_dict_shard, checkpoint=checkpoint, index_file=index_file, base_filename=weights_name, is_master=control_saving, + state_preprocess=False, + ) + self.async_writers.extend(writers) + else: + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, use_pp_format=True, - n_write_entries=191, ) - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=weights_name, - is_master=control_saving, - use_safetensors=use_safetensors, - use_pp_format=True, - ) if control_saving: assert ( @@ -448,26 +496,46 @@ def save_sharded_optimizer( # Then collect the sharded states along dp_group(if using zero)/tp_group. # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving. 
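NOTE: _model_sharder and _optimizer_sharder are generators, and assembling a
shard involves collective gathers across tp/dp ranks, so every rank has to
drain the generator even though only ranks with control_saving set actually
write; the async_save_state_dict_shards helper added in utils.py below keeps
exactly that discipline. A condensed sketch of the consumer loop (async_write
is a hypothetical stand-in for the real safetensors writer):

    total_size, writers = 0, []
    for shard, shard_size in sharded_state_dict:
        if not is_master:
            del shard  # this rank took part in the gather; nothing to write
            continue
        total_size += shard_size
        writers.append(async_write(shard))  # returns a writer to sync later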
+ control_saving = self.dp_rank == 0 and self.tp_rank == 0 + + if use_async and control_saving: + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(optimizer)] = {} + pinned_state_dicts = self.pinned_state_dicts[id(optimizer)] + else: + pinned_state_dicts = None state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder( optimizer, use_zero=self.use_zero, dp_group=self.global_dp_group, tp_group=self.tp_group, size_per_shard=size_per_shard, + pinned_state_dicts=pinned_state_dicts, ) - states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async) index_file = CheckpointIndexFile(checkpoint) - control_saving = self.dp_rank == 0 and self.tp_rank == 0 if self.pp_size == 1: # When pipeline is not used, save the optimizer shards as in general checkpointIO - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=control_saving, - ) + if use_async: + total_size, writers = async_save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + use_pp_format=True, + state_preprocess=True, + ) + self.async_writers.extend(writers) + else: + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + ) if control_saving: # Store param groups. @@ -499,17 +567,30 @@ def save_sharded_optimizer( # Manage filenames of sharded weights and index file for each pipeline stage. states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin") + states_name = states_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors") save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") save_index_file = os.path.join("tmp_index_files", save_index_file) - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=control_saving, - use_pp_format=True, - ) + if use_async: + total_size, writers = async_save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + use_pp_format=True, + state_preprocess=True, + ) + self.async_writers.extend(writers) + else: + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + use_pp_format=True, + ) if control_saving: assert ( @@ -622,7 +703,10 @@ def _get_param_id_from_optimizer_param( continue file_path = os.path.join(ckpt_root_path, filename) - state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) + if file_path.endswith(".safetensors"): + state_dict = load_flat(file_path) + else: + state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) loaded_file.add(filename) @@ -672,7 +756,14 @@ def save_unsharded_model( # When pipeline is not used, let master rank directly save the collected state_dict. 
if self.tp_rank == 0: if use_async: - super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async) + from colossalai.utils.safetensors import save + if id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) + for name, param in state_dict.items(): + self.pinned_state_dicts[id(model)][name].copy_(param) + state_dict[name] = self.pinned_state_dicts[id(model)][name] + writer = save(path=checkpoint, state_dict=state_dict) + self.async_writers.append(writer) else: save_state_dict(state_dict, checkpoint, use_safetensors) else: @@ -686,12 +777,12 @@ def save_unsharded_model( for _state_dict in state_dict_list: complete_state_dict.update(_state_dict) if use_async: - - from colossalai.utils.safetensors import move_and_save - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) - writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)]) + self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict) + for name, param in complete_state_dict.items(): + self.pinned_state_dicts[id(model)][name].copy_(param) + complete_state_dict[name] = self.pinned_state_dicts[id(model)][name] + writer = save(path=checkpoint, state_dict=complete_state_dict) self.async_writers.append(writer) else: save_state_dict(complete_state_dict, checkpoint, use_safetensors) @@ -757,6 +848,7 @@ def save_unsharded_optimizer( # gather complete state from tp shards & dp shards param_id = optimizer.param_info["param2id"][id(working_param)] original_shape = optimizer.param_info["param2shape"][id(working_param)] + local_states[param_id] = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state( state, working_param, @@ -776,7 +868,20 @@ def save_unsharded_optimizer( ] state_dict = {"param_groups": param_groups, "state": local_states} if self.coordinator.is_master(): - save_state_dict(state_dict, checkpoint, use_safetensors=False) + if use_async: + from tensornvme.async_file_io import AsyncFileWriter + + writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread") + flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict) + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts = create_pinned_state_dict(flatten_state_dict) + for k, v in flatten_state_dict.items(): + self.pinned_state_dicts[k].copy_(v) + flatten_state_dict[k] = self.pinned_state_dicts[k] + self.async_writers.append(writer) + save(f_writer=writer, state_dict=flatten_state_dict, metadata=metadata) + else: + save_state_dict(state_dict, checkpoint, use_safetensors=False) else: # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict. 
states_list = [None for _ in range(self.pp_size)] @@ -792,7 +897,20 @@ def save_unsharded_optimizer( state_dict = {"param_groups": param_groups, "state": dict()} for _states in states_list: state_dict["state"].update(_states) - save_state_dict(state_dict, checkpoint, use_safetensors=False) + if use_async: + from tensornvme.async_file_io import AsyncFileWriter + + writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread") + flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict) + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts = create_pinned_state_dict(flatten_state_dict) + for k, v in flatten_state_dict.items(): + self.pinned_state_dicts[k].copy_(v) + flatten_state_dict[k] = self.pinned_state_dicts[k] + self.async_writers.append(writer) + save(f_writer=writer, state_dict=flatten_state_dict, metadata=metadata) + else: + save_state_dict(state_dict, checkpoint, use_safetensors=False) def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str): """ @@ -818,7 +936,10 @@ def _get_param_id_from_optimizer_param( assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" # Complete optimizer state_dict loaded from checkpoint, need to be processed later. - state_dict = load_state_dict(checkpoint) + if checkpoint.endswith(".safetensors"): + state_dict = load_flat(checkpoint) + else: + state_dict = load_state_dict(checkpoint) # Load param_groups. updated_groups = [] @@ -872,6 +993,7 @@ def gather_from_sharded_optimizer_state( use_zero: bool, inplace: bool, device: torch.device = torch.device("cpu"), + pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None, ) -> OrderedDict: """ With given parameter and its optimizer states, gather the complete optimizer state for saving. @@ -915,7 +1037,13 @@ def gather_from_sharded_optimizer_state( v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim) v = to_unpadded_tensor(v) - state_[k] = v.detach().clone().to(device) + if pinned_state_dicts is not None: + if k not in pinned_state_dicts: + pinned_state_dicts[k] = torch.empty_like(v, pin_memory=True, device="cpu") + pinned_state_dicts[k].copy_(v) + state_[k] = pinned_state_dicts[k] + else: + state_[k] = v.detach().clone().to(device) return state_ diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index ab599b556937..920bd19a271a 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -19,6 +19,7 @@ to_global, to_global_for_customized_distributed_tensor, ) +from colossalai.utils.safetensors import _flatten_optim_state_dict SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" @@ -266,6 +267,63 @@ def save_state_dict_shards( def async_save_state_dict_shards( + sharded_state_dict: Iterator[Tuple[OrderedDict, int]], + checkpoint: str, + index_file: "CheckpointIndexFile", + base_filename: str, + is_master: bool, + use_pp_format: bool = False, + state_preprocess: bool = False, +) -> Tuple[int, list]: + """ + Save sharded state dict only on master rank, this method can be used by both model and optimizer states. + Args: + sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size. + checkpoint (str): The path of checkpoint directory as string. + index_file (CheckpointIndexFile): The index file object to be updated. + base_filename (str): Decides the prefix of filenames of shards. 
+ is_master (bool): Whether current rank is main process. + use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False. + use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False. + + Returns: + int: the total size of shards + """ + from colossalai.utils.safetensors import save + + total_size = 0 + shard_filenames = [] + writers = [] + for idx, shard_pair in enumerate(sharded_state_dict): + shard, current_size = shard_pair + # Just loop over the sharder and gather to other ranks if not master + if not is_master: + del shard + continue + shard_file = get_shard_filename(base_filename, idx) + total_size = total_size + current_size + for key in shard.keys(): + index_file.append_weight_map(key, shard_file) + checkpoint_file_path = os.path.join(checkpoint, shard_file) + + if state_preprocess: + state_dict, _ = _flatten_optim_state_dict(state_dict=shard) + else: + state_dict = shard + + # Only save on master rank. + writer = save(checkpoint_file_path, state_dict=state_dict) + writers.append(writer) + shard_filenames.append(shard_file) + del shard + + # Clean folder, deleted unneeded files. + clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format) + + return total_size, writers + + +def async_move_save_state_dict_shards( sharded_state_dict: Iterator[Tuple[OrderedDict, int]], checkpoint: str, index_file: "CheckpointIndexFile", @@ -864,5 +922,6 @@ def get_shard_filename(weights_name: str, idx: int): def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]): pin_mem = dict() for name, tensor in state_dict.items(): - pin_mem[name] = torch.empty_like(tensor, pin_memory=True, device="cpu") + pin_mem[name] = torch.empty(tensor.shape, pin_memory=False, dtype=tensor.dtype, device="cpu", requires_grad=False) return pin_mem + diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index 86d7924fb828..81d184f7681a 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -38,12 +38,13 @@ ] -@parameterize("shard", [True, False]) +@parameterize("shard", [False, True]) @parameterize("model_name", ["transformers_llama_for_causal_lm"]) @parameterize("size_per_shard", [32]) @parameterize("test_config", TEST_CONFIGS) +@parameterize("use_async", [False, True]) @clear_cache_before_run() -def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict): +def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool): (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next( iter(model_zoo.get_sub_registry(model_name).values()) ) @@ -85,8 +86,16 @@ def _preprocess_data(data): with shared_tempdir() as tempdir: model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" - booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) - booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) + if not shard and use_async: + model_ckpt_path = f"{model_ckpt_path}.safetensors" + optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors" + + booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async) + booster.save_optimizer( + optimizer, 
optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async + ) + booster.checkpoint_io._sync_d2h() + booster.checkpoint_io._sync_io() dist.barrier() new_model = model_fn().cuda() From bcb66278a468111d24c915c97cbdfd7aa526aae2 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 26 Nov 2024 10:48:30 +0000 Subject: [PATCH 02/18] fix --- .../test_hybrid_parallel_plugin_checkpoint_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index 81d184f7681a..d6cbb95db87a 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -42,7 +42,7 @@ @parameterize("model_name", ["transformers_llama_for_causal_lm"]) @parameterize("size_per_shard", [32]) @parameterize("test_config", TEST_CONFIGS) -@parameterize("use_async", [False, True]) +@parameterize("use_async", [False]) @clear_cache_before_run() def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool): (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next( From af439f07a079005ad1410cccf65cf3188cfdd058 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 26 Nov 2024 12:51:44 +0000 Subject: [PATCH 03/18] fix --- .../test_hybrid_parallel_plugin_checkpoint_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index d6cbb95db87a..81d184f7681a 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -42,7 +42,7 @@ @parameterize("model_name", ["transformers_llama_for_causal_lm"]) @parameterize("size_per_shard", [32]) @parameterize("test_config", TEST_CONFIGS) -@parameterize("use_async", [False]) +@parameterize("use_async", [False, True]) @clear_cache_before_run() def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool): (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next( From 4bf922460aea50fa6c9951594e0ef351e1878db1 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 27 Nov 2024 14:54:40 +0800 Subject: [PATCH 04/18] fix --- colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 4c1defe27556..9373379e5ee5 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -112,8 +112,6 @@ def _model_sharder( pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu") pinned_state_dicts[prefix + name].copy_(param_) param_ = pinned_state_dicts[prefix + name] - else: - param_ = param_.cpu() block, block_size = state_dict_sharder.append_param(prefix + name, param_) if block is not None: yield block, block_size @@ -128,8 +126,6 @@ def _model_sharder( pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu") pinned_state_dicts[prefix + name].copy_(buffer) buffer = pinned_state_dicts[prefix + name] - else: - buffer = 
buffer.cpu() block, block_size = state_dict_sharder.append_param(prefix + name, buffer) if block is not None: yield block, block_size @@ -146,8 +142,6 @@ def _model_sharder( pinned_state_dicts[extra_state_key] = torch.empty_like(param_, pin_memory=True, device="cpu") pinned_state_dicts[extra_state_key].copy_(extra_state) extra_state = pinned_state_dicts[extra_state_key] - else: - extra_state = extra_state.cpu() block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) if block is not None: yield block, block_size From db55aa8b963cd54288c5a7ea3e5c7e84647a6374 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Dec 2024 03:11:11 +0000 Subject: [PATCH 05/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 1 + colossalai/checkpoint_io/utils.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 9373379e5ee5..5cf5b9019f61 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -751,6 +751,7 @@ def save_unsharded_model( if self.tp_rank == 0: if use_async: from colossalai.utils.safetensors import save + if id(model) not in self.pinned_state_dicts: self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) for name, param in state_dict.items(): diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 920bd19a271a..ed3ec4e0ecff 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -922,6 +922,7 @@ def get_shard_filename(weights_name: str, idx: int): def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]): pin_mem = dict() for name, tensor in state_dict.items(): - pin_mem[name] = torch.empty(tensor.shape, pin_memory=False, dtype=tensor.dtype, device="cpu", requires_grad=False) + pin_mem[name] = torch.empty( + tensor.shape, pin_memory=False, dtype=tensor.dtype, device="cpu", requires_grad=False + ) return pin_mem - From b8b7fa2609df99cb309e6d93cb3cb7c315599da8 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 10 Dec 2024 11:25:37 +0800 Subject: [PATCH 06/18] fix --- .../checkpoint_io/hybrid_parallel_checkpoint_io.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 5cf5b9019f61..f63e17a58be5 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -864,17 +864,15 @@ def save_unsharded_optimizer( state_dict = {"param_groups": param_groups, "state": local_states} if self.coordinator.is_master(): if use_async: - from tensornvme.async_file_io import AsyncFileWriter - - writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread") + from colossalai.utils.safetensors import save flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict) if id(optimizer) not in self.pinned_state_dicts: self.pinned_state_dicts = create_pinned_state_dict(flatten_state_dict) for k, v in flatten_state_dict.items(): self.pinned_state_dicts[k].copy_(v) flatten_state_dict[k] = self.pinned_state_dicts[k] + writer = 
save(path=checkpoint, state_dict=flatten_state_dict, metadata=metadata) self.async_writers.append(writer) - save(f_writer=writer, state_dict=flatten_state_dict, metadata=metadata) else: save_state_dict(state_dict, checkpoint, use_safetensors=False) else: @@ -893,17 +891,15 @@ def save_unsharded_optimizer( for _states in states_list: state_dict["state"].update(_states) if use_async: - from tensornvme.async_file_io import AsyncFileWriter - - writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread") + from colossalai.utils.safetensors import save flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict) if id(optimizer) not in self.pinned_state_dicts: self.pinned_state_dicts = create_pinned_state_dict(flatten_state_dict) for k, v in flatten_state_dict.items(): self.pinned_state_dicts[k].copy_(v) flatten_state_dict[k] = self.pinned_state_dicts[k] + writer=save(path=checkpoint, state_dict=flatten_state_dict, metadata=metadata) self.async_writers.append(writer) - save(f_writer=writer, state_dict=flatten_state_dict, metadata=metadata) else: save_state_dict(state_dict, checkpoint, use_safetensors=False) From 6d9906d088ae319c992df911f76e993eba0627b7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Dec 2024 03:26:35 +0000 Subject: [PATCH 07/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index f63e17a58be5..1cda03d84f27 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -22,7 +22,7 @@ to_unpadded_tensor, ) from colossalai.utils import get_current_device, get_non_persistent_buffers_set -from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat, save +from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile @@ -865,6 +865,7 @@ def save_unsharded_optimizer( if self.coordinator.is_master(): if use_async: from colossalai.utils.safetensors import save + flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict) if id(optimizer) not in self.pinned_state_dicts: self.pinned_state_dicts = create_pinned_state_dict(flatten_state_dict) @@ -892,13 +893,14 @@ def save_unsharded_optimizer( state_dict["state"].update(_states) if use_async: from colossalai.utils.safetensors import save + flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict) if id(optimizer) not in self.pinned_state_dicts: self.pinned_state_dicts = create_pinned_state_dict(flatten_state_dict) for k, v in flatten_state_dict.items(): self.pinned_state_dicts[k].copy_(v) flatten_state_dict[k] = self.pinned_state_dicts[k] - writer=save(path=checkpoint, state_dict=flatten_state_dict, metadata=metadata) + writer = save(path=checkpoint, state_dict=flatten_state_dict, metadata=metadata) self.async_writers.append(writer) else: save_state_dict(state_dict, checkpoint, use_safetensors=False) From 8cc57406e78bfe60c1f66091be4f6b42ac2dc04a Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 10 Dec 2024 15:52:22 +0800 Subject: [PATCH 08/18] fix --- 
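Commentary on the async save paths touched here: the optimizer state dict is
first flattened so that every tensor lives under a flat key of the form
state{sep}{param_id}{sep}{state_key}, non-tensor values are stashed as JSON in
the safetensors metadata, and save() returns an async writer handle that the
caller appends to self.async_writers. A hedged sketch of the caller-side shape,
with the sync calls taken from the test changes in patch 01:

    # flatten, stage into pinned memory, then hand off to the background writer
    flat, meta = _flatten_optim_state_dict(state_dict)
    writer = save(path=checkpoint, state_dict=flat, metadata=meta)
    self.async_writers.append(writer)

    # later, before the checkpoint files are read back:
    checkpoint_io._sync_d2h()  # wait for device-to-host copies to finish
    checkpoint_io._sync_io()   # wait for the background file writes to land

The flat-key separator defaults to "-" rather than "." because FSDP state keys
are fully qualified parameter names that already contain dots.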
colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 1cda03d84f27..d2dd1b9ebc15 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -772,6 +772,7 @@ def save_unsharded_model( for _state_dict in state_dict_list: complete_state_dict.update(_state_dict) if use_async: + from colossalai.utils.safetensors import save if id(model) not in self.pinned_state_dicts: self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict) for name, param in complete_state_dict.items(): From 3f902c74c3289a7384409afb1228a1d5bfd2c9e9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Dec 2024 07:53:19 +0000 Subject: [PATCH 09/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index d2dd1b9ebc15..d553294d0838 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -773,6 +773,7 @@ def save_unsharded_model( complete_state_dict.update(_state_dict) if use_async: from colossalai.utils.safetensors import save + if id(model) not in self.pinned_state_dicts: self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict) for name, param in complete_state_dict.items(): From 237652dfc7198c460c3f931178eb63b4ad39a93f Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 11 Dec 2024 18:46:39 +0800 Subject: [PATCH 10/18] support for other models --- colossalai/booster/plugin/gemini_plugin.py | 109 +++++++++++++----- colossalai/booster/plugin/torch_ddp_plugin.py | 2 +- .../booster/plugin/torch_fsdp_plugin.py | 109 ++++++++++++++---- .../checkpoint_io/general_checkpoint_io.py | 63 +++++++--- colossalai/checkpoint_io/utils.py | 34 ++++-- colossalai/utils/safetensors.py | 23 ++-- colossalai/zero/gemini/gemini_ddp.py | 16 +++ colossalai/zero/gemini/gemini_optimizer.py | 8 +- .../test_gemini_checkpoint_io.py | 33 +++++- .../test_general_checkpoint_io.py | 57 ++++++--- .../test_torch_ddp_checkpoint_io.py | 16 ++- .../test_torch_fsdp_checkpoint_io.py | 24 ++-- 12 files changed, 371 insertions(+), 123 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 30c1257ef14c..b079abcb6b1c 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -23,6 +23,8 @@ save_config_file, save_state_dict, save_state_dict_shards, + async_save_state_dict_shards, + create_pinned_state_dict ) from colossalai.cluster import DistCoordinator, ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper @@ -32,6 +34,7 @@ from colossalai.zero.gemini.memory_tracer import MemStats from .dp_plugin_base import DPPluginBase +from colossalai.utils.safetensors import load_flat __all__ = ["GeminiPlugin"] @@ -82,7 +85,14 @@ def save_unsharded_model( state_dict = model.state_dict(only_rank_0=True) if self.coordinator.is_master(): if use_async: - super().save_unsharded_model(model, checkpoint, 
gather_dtensor, use_safetensors, use_async) + from colossalai.utils.safetensors import save + if id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) + for k, v in state_dict.items(): + self.pinned_state_dicts[id(model)][k].copy_(v) + state_dict[k] = self.pinned_state_dicts[id(model)][k] + writer = save(checkpoint, state_dict) + self.async_writers.append(writer) else: save_state_dict(state_dict, checkpoint, use_safetensors) @@ -106,7 +116,18 @@ def save_unsharded_optimizer( assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!" state_dict = optimizer.state_dict() if self.coordinator.is_master(): - save_state_dict(state_dict, checkpoint, use_safetensors=False) + if use_async: + from colossalai.utils.safetensors import save, _flatten_optim_state_dict + flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict) + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict) + for k, v in flatten_state_dict.items(): + self.pinned_state_dicts[id(optimizer)][k].copy_(v) + flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k] + writer = save(checkpoint, flatten_state_dict, metadata) + self.async_writers.append(writer) + else: + save_state_dict(state_dict, checkpoint, use_safetensors=False) def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str): """ @@ -136,18 +157,28 @@ def save_sharded_model( return Path(checkpoint_path).mkdir(parents=True, exist_ok=True) - - state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True) + + if use_async and self.coordinator.is_master(): + if id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = {} + pinned_state_dicts = self.pinned_state_dicts[id(model)] + else: + pinned_state_dicts = None + state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, pinned_state_dicts=pinned_state_dicts) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint_path) # Save shards of optimizer states. is_master = self.coordinator.is_master() if use_async: - super().save_sharded_model( - model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async + total_size, writers = async_save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=True, ) - + self.async_writers.extend(writers) else: total_size = save_state_dict_shards( sharded_state_dict=state_dict_shard, @@ -158,17 +189,17 @@ def save_sharded_model( use_safetensors=use_safetensors, ) - # only save the index file on the master rank - if self.coordinator.is_master(): - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - save_config_file(model.unwrap(), checkpoint_path) - self.logger.info( - f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.", - ranks=[0], - ) + # only save the index file on the master rank + if self.coordinator.is_master(): + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model.unwrap(), checkpoint_path) + self.logger.info( + f"The model is split into checkpoint shards. 
" + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.", + ranks=[0], + ) def load_sharded_model( self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False @@ -201,7 +232,7 @@ def save_sharded_optimizer( Path(checkpoint).mkdir(parents=True, exist_ok=True) # Preparing file paths and index file. - states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async) index_file = CheckpointIndexFile(checkpoint) index_file.append_meta_data("param_groups", param_group_file) @@ -212,17 +243,34 @@ def save_sharded_optimizer( torch.save(param_groups, group_file_path) # States are broken into shards within max_shard_size. - state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True) + if use_async and self.coordinator.is_master(): + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(optimizer)] = {} + pinned_state_dicts = self.pinned_state_dicts[id(optimizer)] + else: + pinned_state_dicts = None + state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True, pinned_state_dicts=pinned_state_dicts) # Save shards of optimizer states. - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=self.coordinator.is_master(), - use_safetensors=False, - ) + if use_async: + total_size, writers = async_save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=True, + state_preprocess=True + ) + self.async_writers.extend(writers) + else: + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=self.coordinator.is_master(), + use_safetensors=False, + ) # Wrap up index file. Only save it on master rank. if self.coordinator.is_master(): @@ -264,7 +312,10 @@ def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_fi # Load optimizer states from shard files under checkpoint path. # For each file, only load the states managed by current process. for shard_file in checkpoint_files: - state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False) + if shard_file.endswith(".safetensors"): + state_dict_shard = load_flat(shard_file) + else: + state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False) optimizer.load_param_states(state_dict_shard) del state_dict_shard gc.collect() diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 07be5b0516f6..90d406eefaa3 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -60,7 +60,7 @@ def save_unsharded_optimizer( """ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" 
if self.coordinator.is_master(): - super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) + super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async) def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): """ diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index b80d6d4b6eb8..a969b6ea9de5 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -29,6 +29,8 @@ from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.logging import get_dist_logger +from colossalai.checkpoint_io.utils import create_pinned_state_dict, async_save_state_dict_shards +from colossalai.utils.safetensors import load_flat from .dp_plugin_base import DPPluginBase @@ -49,7 +51,10 @@ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: boo def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path): assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!" - checkpoint = utils.load_state_dict(checkpoint) + if checkpoint.endswith(".safetensors"): + checkpoint = load_flat(checkpoint, seperator='-') + else: + checkpoint = utils.load_state_dict(checkpoint) fsdp_model = optimizer.unwrap_model() sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model) optimizer.load_state_dict(sharded_osd) @@ -65,7 +70,17 @@ def save_unsharded_model( cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): full_model_state = model.state_dict() - utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors) + if use_async: + from colossalai.utils.safetensors import save + if id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state) + for k, v in full_model_state.items(): + self.pinned_state_dicts[id(model)][k].copy_(v) + full_model_state[k] = self.pinned_state_dicts[id(model)][k] + writer = save(checkpoint, full_model_state) + self.async_writers.append(writer) + else: + utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors) def save_unsharded_optimizer( self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False @@ -76,7 +91,18 @@ def save_unsharded_optimizer( assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!" 
fsdp_model = optimizer.unwrap_model() full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True) - utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False) + if use_async: + from colossalai.utils.safetensors import save, _flatten_optim_state_dict + flatten_state_dict, metadata = _flatten_optim_state_dict(full_optimizer_state, seperator='-') + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict) + for k, v in flatten_state_dict.items(): + self.pinned_state_dicts[id(optimizer)][k].copy_(v) + flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k] + writer = save(checkpoint, state_dict=flatten_state_dict, metadata=metadata) + self.async_writers.append(writer) + else: + utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False) def save_sharded_model( self, @@ -102,20 +128,36 @@ def save_sharded_model( ): state_dict = model.unwrap().state_dict() - state_dict_shard = utils.shard_model_checkpoint(state_dict, max_shard_size=size_per_shard) + if use_async and self.coordinator.is_master(): + if id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = {} + pinned_state_dicts = self.pinned_state_dicts[id(model)] + else: + pinned_state_dicts = None + state_dict_shard = utils.shard_model_checkpoint(state_dict, max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts) weights_name, save_index_file = utils.get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint_path) # In general cases, is_master is set to True to get the right behavior. - total_size = utils.save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint_path, - index_file=index_file, - base_filename=weights_name, - is_master=self.coordinator.is_master(), - use_safetensors=use_safetensors, - ) + if use_async: + total_size, writers = async_save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=True, + ) + self.async_writers.extend(writers) + else: + total_size = utils.save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=self.coordinator.is_master(), + use_safetensors=use_safetensors, + ) # only save the index file on the master rank if self.coordinator.is_master(): @@ -189,25 +231,41 @@ def save_sharded_optimizer( if self.coordinator.is_master(): # Preparing file paths and index file. 
- states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(prefix) + states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(prefix, use_safetensors=use_async) index_file = CheckpointIndexFile(checkpoint) index_file.append_meta_data("param_groups", param_group_file) group_file_path = os.path.join(checkpoint, param_group_file) utils.save_param_groups(fsdp_optim_state, group_file_path) - sharded_state = utils.shard_optimizer_checkpoint(fsdp_optim_state, max_shard_size=size_per_shard) - + if use_async: + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(optimizer)] = {} + pinned_state_dicts = self.pinned_state_dicts[id(optimizer)] + else: + pinned_state_dicts = None + sharded_state = utils.shard_optimizer_checkpoint(fsdp_optim_state, max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts) # Save shards of optimizer states. # In general cases, is_master is set to True to get the right behavior. - total_size = utils.save_state_dict_shards( - sharded_state_dict=sharded_state, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=self.coordinator.is_master(), - use_safetensors=False, - ) + if use_async: + total_size, writers = async_save_state_dict_shards( + sharded_state_dict=sharded_state, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=True, + state_preprocess=True + ) + self.async_writers.extend(writers) + else: + total_size = utils.save_state_dict_shards( + sharded_state_dict=sharded_state, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=self.coordinator.is_master(), + use_safetensors=False, + ) index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) @@ -239,7 +297,10 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, siz fsdp_optim_state = {} checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() for shard_file in checkpoint_files: - state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False) + if shard_file.endswith(".safetensors"): + state_dict_shard = load_flat(shard_file, seperator='-') + else: + state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False) fsdp_optim_state.update(state_dict_shard) fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index a75a625e1920..976475469658 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -30,6 +30,8 @@ sharded_optimizer_loading_epilogue, ) +from colossalai.utils.safetensors import load_flat + __all__ = ["GeneralCheckpointIO"] @@ -47,18 +49,13 @@ def save_unsharded_model( ): state_dict = model.state_dict() - # TODO(FrankLeeeee): add support for gather_dtensor - if gather_dtensor: - pass - if use_async: from colossalai.utils.safetensors import move_and_save - + if id(model) not in self.pinned_state_dicts: self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)]) self.async_writers.append(writer) - else: # save the checkpoint save_state_dict(state_dict, checkpoint, use_safetensors) @@ -83,7 +80,10 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre 
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() for shard_file in checkpoint_files: - state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False) + if shard_file.endswith(".safetensors"): + state_dict = load_flat(shard_file) + else: + state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False) load_states_into_optimizer(optimizer, state_dict, id_map) sharded_optimizer_loading_epilogue(optimizer) @@ -116,7 +116,7 @@ def save_sharded_optimizer( sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard) # Preparing file paths and index file. - states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async) index_file = CheckpointIndexFile(checkpoint) # Store the information of param groups to param_group_file. @@ -126,14 +126,28 @@ def save_sharded_optimizer( # Save shards of optimizer states. # In general cases, is_master is set to True to get the right behavior. - total_size = save_state_dict_shards( - sharded_state_dict=sharded_state, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=True, - use_safetensors=False, - ) + if use_async: + pinned_state_dict = self.pinned_state_dicts.get(id(optimizer), None) + total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards( + sharded_state_dict=sharded_state, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=True, + pinned_state_dict=pinned_state_dict, + state_preprocess=True + ) + self.pinned_state_dicts[id(optimizer)] = new_pinned_state_dict + self.async_writers.extend(writers) + else: + total_size = save_state_dict_shards( + sharded_state_dict=sharded_state, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=True, + use_safetensors=False, + ) # Wrap up index file. 
index_file.append_meta_data("total_size", total_size) @@ -145,7 +159,10 @@ def save_sharded_optimizer( ) def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): - checkpoint = load_state_dict(checkpoint) + if checkpoint.endswith(".safetensors"): + checkpoint = load_flat(checkpoint) + else: + checkpoint = load_state_dict(checkpoint) optimizer.load_state_dict(checkpoint) def save_unsharded_optimizer( @@ -156,7 +173,17 @@ def save_unsharded_optimizer( use_async: bool = False, ): # TODO(FrankLeeeee): handle distributed tensors - save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False) + state_dict = optimizer.state_dict() + if use_async: + from colossalai.utils.safetensors import move_and_save, _flatten_optim_state_dict + + flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict) + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict) + writer = move_and_save(path=checkpoint, state_dict=flatten_state_dict, state_dict_pinned=self.pinned_state_dicts[id(optimizer)], metadata=metadata) + self.async_writers.append(writer) + else: + save_state_dict(state_dict, checkpoint, use_safetensors=False) def save_sharded_model( self, diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index ed3ec4e0ecff..107846737474 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -307,7 +307,7 @@ def async_save_state_dict_shards( checkpoint_file_path = os.path.join(checkpoint, shard_file) if state_preprocess: - state_dict, _ = _flatten_optim_state_dict(state_dict=shard) + state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator='-') else: state_dict = shard @@ -331,6 +331,7 @@ def async_move_save_state_dict_shards( is_master: bool, pinned_state_dict: Optional[Dict[str, torch.Tensor]], use_pp_format: bool = False, + state_preprocess: bool = False, ) -> Tuple[int, Dict[str, torch.Tensor], list]: """ Save sharded state dict only on master rank, this method can be used by both model and optimizer states. @@ -367,14 +368,19 @@ def async_move_save_state_dict_shards( index_file.append_weight_map(key, shard_file) checkpoint_file_path = os.path.join(checkpoint, shard_file) + if state_preprocess: + state_dict, _ = _flatten_optim_state_dict(state_dict=shard) + else: + state_dict = shard + if pinned_state_dict is not None: - sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()} + sub_pinned_state_dict = {k: pinned_state_dict[k] for k in state_dict.keys()} else: - sub_pinned_state_dict = create_pinned_state_dict(shard) + sub_pinned_state_dict = create_pinned_state_dict(state_dict) returned_state_dict.update(sub_pinned_state_dict) # Only save on master rank. 
- writer = move_and_save(checkpoint_file_path, shard, sub_pinned_state_dict) + writer = move_and_save(checkpoint_file_path, state_dict, sub_pinned_state_dict) writers.append(writer) shard_filenames.append(shard_file) del shard @@ -385,7 +391,7 @@ def async_move_save_state_dict_shards( return total_size, returned_state_dict, writers -def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: +def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[int,Dict[str, torch.Tensor]]] = None) -> Iterator[Tuple[OrderedDict, int]]: """ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a given size. @@ -394,6 +400,10 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) for key, weight in state_dict.items(): if not is_distributed_tensor(weight): + if pinned_state_dicts is not None: + pinned_state_dicts[key] = torch.empty_like(weight, pin_memory=True, device="cpu") + pinned_state_dicts[key].copy_(weight) + weight = pinned_state_dicts[key] block, block_size = state_dict_sharder.append_param(key, weight) if block != None: @@ -403,7 +413,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) yield state_dict_sharder.current_block, state_dict_sharder.current_block_size -def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: +def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, torch.Tensor]]=None) -> Iterator[Tuple[OrderedDict, int]]: """ Splits an optimizer state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a given size. 
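NOTE: unlike the sharders elsewhere in this series, the shard_model_checkpoint
hunk above allocates a fresh pinned buffer on every call; reusing the cache
across saves needs the usual existence check, roughly:

    if pinned_state_dicts is not None:
        if key not in pinned_state_dicts:  # allocate once, reuse on later saves
            pinned_state_dicts[key] = torch.empty_like(weight, pin_memory=True, device="cpu")
        pinned_state_dicts[key].copy_(weight)
        weight = pinned_state_dicts[key]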
@@ -414,6 +424,14 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
     state_dict_sharder = StateDictSharder(max_shard_size)
 
     for param_id, state in states.items():
+        if pinned_state_dicts is not None:
+            if param_id not in pinned_state_dicts:
+                pinned_state_dicts[param_id] = {}
+            for k, v in state.items():
+                pinned_state_dicts[param_id][k] = torch.empty_like(v, pin_memory=True, device="cpu")
+                pinned_state_dicts[param_id][k].copy_(v)
+                state[k] = pinned_state_dicts[param_id][k]
+
         block, block_size = state_dict_sharder.append_optim_state(param_id, state)
         if block != None:
             yield block, block_size
@@ -922,7 +940,5 @@ def get_shard_filename(weights_name: str, idx: int):
 def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]):
     pin_mem = dict()
     for name, tensor in state_dict.items():
-        pin_mem[name] = torch.empty(
-            tensor.shape, pin_memory=False, dtype=tensor.dtype, device="cpu", requires_grad=False
-        )
+        pin_mem[name] = torch.empty_like(tensor, pin_memory=True, device="cpu")
     return pin_mem
diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py
index d8983436d950..e8d0f1f34904 100644
--- a/colossalai/utils/safetensors.py
+++ b/colossalai/utils/safetensors.py
@@ -5,6 +5,7 @@
 import torch
 from safetensors.torch import _TYPES, load_file, safe_open
+import torch.distributed
 
 try:
     from tensornvme.async_file_io import AsyncFileWriter
@@ -59,7 +60,7 @@ def _cast_to_object(tensor: torch.Tensor):
     return _tensor_to_object(tensor, tensor.numel() * tensor.element_size())
 
 
-def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[dict, Optional[dict]]:
+def _flatten_optim_state_dict(state_dict: dict, seperator: str = "-") -> Tuple[dict, Optional[dict]]:
     flat_dict = {}
     non_tensor_keys = []
     if "state" in state_dict:
@@ -87,7 +88,8 @@ def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[d
 
 def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None, seperator: str = "."):
     state_dict = {}
-    if metadata is not None:
+
+    if metadata is not None and "non_tensor_keys" in metadata:
         non_tensor_keys = json.loads(metadata["non_tensor_keys"])
     else:
         non_tensor_keys = []
@@ -104,7 +106,11 @@ def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None
     for k, v in flat_dict.items():
         parts = k.split(seperator)
         assert len(parts) == 3 and parts[0] == "state"
-        idx = int(parts[1])
+        try:
+            idx = int(parts[1])
+        except:
+            # exception for fsdp, parts[1] isn't param_id
+            idx = parts[1]
         key = parts[2]
         if idx not in states:
             states[idx] = {}
@@ -128,8 +134,10 @@ def prepare(
     header = {}
     offset = 0
 
+    header_metadata = {"format": "pt"}
     if metadata is not None:
-        header["__metadata__"] = metadata
+        header_metadata.update(metadata)
+    header["__metadata__"] = header_metadata
 
     for name, tensor in data.items():
         n = tensor.numel() * tensor.element_size()
@@ -172,8 +180,9 @@ def move_and_save(
     path: str,
     state_dict: Dict[str, torch.Tensor],
     state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
+    metadata: Optional[Dict[str, str]] = None
 ) -> None:
-    prepared_data, _, tensor_keys = prepare(state_dict)
+    prepared_data, _, tensor_keys = prepare(state_dict, metadata)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
     f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensor_keys))
     f_writer.write(n.to_bytes(8, byteorder="little"))
@@ -188,9 +197,9 @@ def move_and_save(
 
     return f_writer
 
 
-def 
load_flat(checkpoint_path): +def load_flat(checkpoint_path, seperator: str = "-"): with safe_open(checkpoint_path, framework="pt") as f: metadata = f.metadata() state_dict_load = load_file(checkpoint_path) - state_dict = _unflatten_optim_state_dict(state_dict_load, metadata) + state_dict = _unflatten_optim_state_dict(state_dict_load, metadata, seperator) return state_dict diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index a033e917baba..38315a523556 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -903,6 +903,7 @@ def state_dict_shard( keep_vars: bool = False, max_shard_size: int = 1024, only_rank_0: bool = True, + pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None ) -> Iterator[Tuple[OrderedDict, int]]: """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``. @@ -943,6 +944,11 @@ def state_dict_shard( gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0)) gathered_param = gathered_param_buffer.pop(param_to_save) + if pinned_state_dicts is not None: + if (prefix + name) not in pinned_state_dicts: + pinned_state_dicts[prefix + name] = torch.empty_like(gathered_param, pin_memory=True, device="cpu") + pinned_state_dicts[prefix + name].copy_(gathered_param) + gathered_param = pinned_state_dicts[prefix + name] block, block_size = sharder.append_param(prefix + name, gathered_param) if block is not None: yield block, block_size @@ -954,6 +960,11 @@ def state_dict_shard( for name, buf in self.named_buffers(): if buf is not None and name not in self._non_persistent_buffers_set: buffer = buf if keep_vars else buf.detach() + if pinned_state_dicts is not None: + if (prefix + name) not in pinned_state_dicts: + pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu") + pinned_state_dicts[prefix + name].copy_(buffer) + buffer = pinned_state_dicts[prefix + name] block, block_size = sharder.append_param(prefix + name, buffer) if block is not None: yield block, block_size @@ -964,6 +975,11 @@ def state_dict_shard( is not torch.nn.Module.get_extra_state ): extra_state = self.get_extra_state() + if pinned_state_dicts is not None: + if extra_state_key not in pinned_state_dicts: + pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu") + pinned_state_dicts[extra_state_key].copy_(extra_state) + extra_state = pinned_state_dicts[extra_state_key] block, block_size = sharder.append_param(extra_state_key, extra_state) if block is not None: yield block, block_size diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index ca91b4d9f27c..62074adc4612 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -809,7 +809,7 @@ def load_state_dict(self, state_dict: dict): self.optimizer_loading_epilogue() def state_shard( - self, prefix: str = "", max_shard_size: int = 1024, only_rank_0: bool = True + self, prefix: str = "", max_shard_size: int = 1024, only_rank_0: bool = True, pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None ) -> Iterator[Tuple[OrderedDict, int]]: """Returns dictionaries containing shards of optimizer states one by one. The max size of each dictionary shard is specified by ``max_shard_size``. 
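# Illustrative sketch (not part of the patch): model weights use a flat
# {name: pinned tensor} cache, but optimizer states, as in the state_shard
# body below, need one pinned buffer per state entry per parameter, so the
# cache is nested as {param_id: {state_key: tensor}}. Ids and keys here are
# made up.
import torch
from typing import Dict

pinned: Dict[int, Dict[str, torch.Tensor]] = {}
collected = {0: {"exp_avg": torch.randn(8), "exp_avg_sq": torch.rand(8)}}
for param_id, entries in collected.items():
    dst = pinned.setdefault(param_id, {})
    for key, value in entries.items():
        if key not in dst:
            dst[key] = torch.empty_like(value, pin_memory=True, device="cpu")
        dst[key].copy_(value)
assert pinned[0]["exp_avg"].is_pinned()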
@@ -829,6 +829,12 @@ def state_shard( dist.barrier() state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0) + if pinned_state_dicts is not None: + pinned_state_dicts[param_id] = {} + for k, v in state.items(): + pinned_state_dicts[param_id][k] = torch.empty_like(v, pin_memory=True, device="cpu") + pinned_state_dicts[param_id][k].copy_(v) + state[k] = pinned_state_dicts[param_id][k] block, block_size = sharder.append_optim_state(param_id, state) if block is not None: yield block, block_size diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 8bee8fe97290..5a47f9a3680f 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -35,7 +35,8 @@ @parameterize("use_safetensors", [False, True]) @parameterize("tp_size", [1, 2]) @parameterize("zero_size", [2]) -def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int): +@parameterize("use_async", [False, True]) +def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int, use_async: bool): from transformers import BertForSequenceClassification (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) @@ -70,7 +71,10 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b "", (model_size / 3), use_safetensors=use_safetensors, + use_async=use_async ) + booster.checkpoint_io._sync_d2h() + booster.checkpoint_io._sync_io() dist.barrier() new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path) check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict()) @@ -83,7 +87,8 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b @parameterize("size_per_shard", [32]) @parameterize("tp_size", [1, 2]) @parameterize("zero_size", [2]) -def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int): +@parameterize("use_async", [False, True]) +def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int, use_async: bool): (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() enable_flash_attention = True if tp_size > 1 else False @@ -124,14 +129,22 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha with shared_tempdir() as tempdir: model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" + + if not shard and use_async: + model_ckpt_path = f"{model_ckpt_path}.safetensors" + optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors" + booster.save_model( model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, + use_async=use_async ) - booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async) + booster.checkpoint_io._sync_d2h() + booster.checkpoint_io._sync_io() dist.barrier() booster.load_model(new_model, model_ckpt_path) @@ -155,8 +168,18 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha loss = criterion(output[output_key]) booster.backward(loss, new_optimizer) 
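# Illustrative sketch (not part of the patch): the test changes above share
# one calling convention for the async paths. save_* returns once the write
# is queued, _sync_d2h() waits for copies into the pinned host buffers, and
# _sync_io() waits until the files are fully on disk. Condensed from
# test_unsharded_checkpoint, assuming tensornvme is installed:
import tempfile
from torchvision.models import resnet18
from colossalai.checkpoint_io import GeneralCheckpointIO

model = resnet18()
ckpt = tempfile.NamedTemporaryFile(suffix=".safetensors")
io = GeneralCheckpointIO()
io.save_model(model, ckpt.name, use_async=True)
io._sync_d2h()  # pinned-buffer copies complete; tensors may be reused
io._sync_io()   # background writes complete; checkpoint is readable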
new_optimizer.step() - booster.save_model(new_model, model_ckpt_path, shard=shard) - booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard) + + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" + optimizer_ckpt_path = f"{tempdir}/optimizer" + + if not shard and use_async: + model_ckpt_path = f"{model_ckpt_path}.safetensors" + optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors" + booster.save_model(new_model, model_ckpt_path, shard=shard, use_async=use_async) + booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard, use_async=use_async) + booster.checkpoint_io._sync_d2h() + booster.checkpoint_io._sync_io() def exam_lazy_from_pretrained(): diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index 8431036df6b7..25dfef0e25a8 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -19,7 +19,8 @@ @clear_cache_before_run() @parameterize("use_safetensors", [True, False]) -def test_unsharded_checkpoint(use_safetensors: bool): +@parameterize("use_async", [False, True]) +def test_unsharded_checkpoint(use_safetensors: bool, use_async: bool): # create a model and optimizer model = resnet18() optimizer = Adam(model.parameters(), lr=0.001) @@ -36,18 +37,21 @@ def test_unsharded_checkpoint(use_safetensors: bool): lr_scheduler.step() # create a temp file for checkpoint - if use_safetensors: + if use_async or use_safetensors: suffix = ".safetensors" else: suffix = ".bin" model_ckpt_tempfile = tempfile.NamedTemporaryFile(suffix=suffix) - optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile() + if use_async: + optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile(suffix=suffix) + else: + optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile() lr_scheduler_ckpt_tempfile = tempfile.NamedTemporaryFile() # save the model, optimizer, lr_scheduler ckpt_io = GeneralCheckpointIO() - ckpt_io.save_model(model, model_ckpt_tempfile.name, use_safetensors=use_safetensors) - ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name) + ckpt_io.save_model(model, model_ckpt_tempfile.name, use_safetensors=use_safetensors, use_async=use_async) + ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name, use_async=use_async) ckpt_io.save_lr_scheduler(lr_scheduler, lr_scheduler_ckpt_tempfile.name) # create new model @@ -55,6 +59,9 @@ def test_unsharded_checkpoint(use_safetensors: bool): new_optimizer = Adam(new_model.parameters(), lr=0.001) new_lr_scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=10) + ckpt_io._sync_d2h() + ckpt_io._sync_io() + # load the model, optimizer, lr_scheduler ckpt_io.load_model(new_model, model_ckpt_tempfile.name) ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name) @@ -66,7 +73,8 @@ def test_unsharded_checkpoint(use_safetensors: bool): @pytest.mark.parametrize("use_safetensors", [True, False]) -def test_sharded_model_checkpoint(use_safetensors: bool): +@pytest.mark.parametrize("use_async", [False, True]) +def test_sharded_model_checkpoint(use_safetensors: bool, use_async: bool): # create a model and optimizer model = resnet18() optimizer = Adam(model.parameters(), lr=0.001) @@ -79,21 +87,18 @@ def test_sharded_model_checkpoint(use_safetensors: bool): loss.backward() optimizer.step() - # create a temp file for checkpoint - if use_safetensors: - pass - else: - pass - model_ckpt_dir = tempfile.TemporaryDirectory() optimizer_ckpt_tempfile = 
tempfile.NamedTemporaryFile() # save the model and optimizer ckpt_io = GeneralCheckpointIO() - ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors) + ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors, use_async=use_async) ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name, shard=False) + ckpt_io._sync_d2h() + ckpt_io._sync_io() + # create new model new_model = resnet18() new_optimizer = Adam(new_model.parameters(), lr=0.001) @@ -106,7 +111,8 @@ def test_sharded_model_checkpoint(use_safetensors: bool): check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) -def test_sharded_optimizer_checkpoint(): +@pytest.mark.parametrize("use_async", [False, True]) +def test_sharded_optimizer_checkpoint(use_async: bool): # create a model and optimizer model = resnet18() optimizer = Adam(model.parameters(), lr=0.001) @@ -128,7 +134,10 @@ def test_sharded_optimizer_checkpoint(): ckpt_io = GeneralCheckpointIO() ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False) - ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10) + ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10, use_async=use_async) + + ckpt_io._sync_d2h() + ckpt_io._sync_io() # create new model new_model = resnet18() @@ -148,9 +157,17 @@ def test_sharded_optimizer_checkpoint(): loss.backward() new_optimizer.step() + + # create temp directories for checkpoint + model_ckpt_dir = tempfile.TemporaryDirectory() + optimizer_ckpt_dir = tempfile.TemporaryDirectory() + # save the newly got optimizer ckpt_io.save_model(new_model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False) - ckpt_io.save_optimizer(new_optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10) + ckpt_io.save_optimizer(new_optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10, use_async=use_async) + + ckpt_io._sync_d2h() + ckpt_io._sync_io() # create another new model new_new_model = resnet18() @@ -164,7 +181,8 @@ def test_sharded_optimizer_checkpoint(): check_state_dict_equal(new_optimizer.state_dict(), new_new_optimizer.state_dict()) -def test_sharded_optimizer_multiple_param_groups(): +@pytest.mark.parametrize("use_async", [False, True]) +def test_sharded_optimizer_multiple_param_groups(use_async: bool): # create a model and optimizer model = resnet18() optimizer = Adam( @@ -188,7 +206,10 @@ def test_sharded_optimizer_multiple_param_groups(): ckpt_io = GeneralCheckpointIO() ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False) - ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10) + ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10, use_async=use_async) + + ckpt_io._sync_d2h() + ckpt_io._sync_io() # create new model new_model = resnet18() diff --git a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py index 87d35f2526b4..e35317425599 100644 --- a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py @@ -12,9 +12,10 @@ from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn -@parameterize("shard", [True, False]) +@parameterize("shard", [False, True]) @parameterize("size_per_shard", [16, 128]) -def check_torch_ddp_checkpointIO(shard: 
bool, size_per_shard: int): +@parameterize("use_async", [False, True]) +def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bool): plugin = TorchDDPPlugin() booster = Booster(plugin=plugin) model = resnet18() @@ -39,9 +40,16 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int): model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler" - booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) - booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) + + if not shard and use_async: + model_ckpt_path = f"{model_ckpt_path}.safetensors" + optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors" + + booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async) booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path) + booster.checkpoint_io._sync_d2h() + booster.checkpoint_io._sync_io() dist.barrier() new_model = resnet18() diff --git a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py index 12b70cc04d3c..b10a0692a39a 100644 --- a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py @@ -12,7 +12,7 @@ from colossalai.booster.plugin import TorchFSDPPlugin from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import rerun_if_address_is_in_use, spawn, parameterize def compare_nested_dict(dict1, dict2): @@ -42,8 +42,8 @@ def compare_nested_dict(dict1, dict2): return False return True - -def check_torch_fsdp_ckpt(): +@parameterize("use_async", [False, True]) +def check_torch_fsdp_ckpt(use_async: bool): model = resnet18() plugin = TorchFSDPPlugin() booster = Booster(plugin=plugin) @@ -65,10 +65,17 @@ def run_model(): model_ckpt_path = f"{tempdir}/model" optim_ckpt_path = f"{tempdir}/optimizer" + if use_async: + model_ckpt_path = f"{model_ckpt_path}.safetensors" + optim_ckpt_path = f"{optim_ckpt_path}.safetensors" + run_model() - booster.save_model(fsdp_model, model_ckpt_path, shard=False) - booster.save_optimizer(optimizer, optim_ckpt_path, shard=False) + booster.save_model(fsdp_model, model_ckpt_path, shard=False, use_async=use_async) + booster.save_optimizer(optimizer, optim_ckpt_path, shard=False, use_async=use_async) + + booster.checkpoint_io._sync_d2h() + booster.checkpoint_io._sync_io() full_msd = fsdp_model.state_dict() # full_osd = FSDP.full_optim_state_dict(fsdp_model, optimizer) @@ -106,8 +113,11 @@ def run_model(): run_model() - booster.save_model(fsdp_model, model_ckpt_path, shard=True) - booster.save_optimizer(optimizer, optim_ckpt_path, shard=True) + booster.save_model(fsdp_model, model_ckpt_path, shard=True, use_async=False) + booster.save_optimizer(optimizer, optim_ckpt_path, shard=True, use_async=True) + + booster.checkpoint_io._sync_d2h() + booster.checkpoint_io._sync_io() full_msd = fsdp_model.unwrap().state_dict() full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer) From bb305f3abb0c1ca6c619fd3275f8992ca58e0aa8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Dec 2024 10:48:52 
+0000 Subject: [PATCH 11/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/booster/plugin/gemini_plugin.py | 22 ++++++++++------ .../booster/plugin/torch_fsdp_plugin.py | 26 ++++++++++++------- .../checkpoint_io/general_checkpoint_io.py | 17 +++++++----- colossalai/checkpoint_io/utils.py | 12 ++++++--- colossalai/utils/safetensors.py | 6 ++--- colossalai/zero/gemini/gemini_ddp.py | 6 +++-- colossalai/zero/gemini/gemini_optimizer.py | 6 ++++- .../test_gemini_checkpoint_io.py | 24 ++++++++--------- .../test_general_checkpoint_io.py | 5 ++-- .../test_torch_ddp_checkpoint_io.py | 4 ++- .../test_torch_fsdp_checkpoint_io.py | 3 ++- 11 files changed, 83 insertions(+), 48 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index b079abcb6b1c..7fb4d03a4fe8 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -17,24 +17,24 @@ from colossalai.accelerator import get_accelerator from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO from colossalai.checkpoint_io.utils import ( + async_save_state_dict_shards, + create_pinned_state_dict, get_model_base_filenames, get_optimizer_base_filenames, load_shard_state_dict, save_config_file, save_state_dict, save_state_dict_shards, - async_save_state_dict_shards, - create_pinned_state_dict ) from colossalai.cluster import DistCoordinator, ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.logging import get_dist_logger from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.utils.safetensors import load_flat from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.memory_tracer import MemStats from .dp_plugin_base import DPPluginBase -from colossalai.utils.safetensors import load_flat __all__ = ["GeminiPlugin"] @@ -86,6 +86,7 @@ def save_unsharded_model( if self.coordinator.is_master(): if use_async: from colossalai.utils.safetensors import save + if id(model) not in self.pinned_state_dicts: self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) for k, v in state_dict.items(): @@ -117,7 +118,8 @@ def save_unsharded_optimizer( state_dict = optimizer.state_dict() if self.coordinator.is_master(): if use_async: - from colossalai.utils.safetensors import save, _flatten_optim_state_dict + from colossalai.utils.safetensors import _flatten_optim_state_dict, save + flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict) if id(optimizer) not in self.pinned_state_dicts: self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict) @@ -157,14 +159,16 @@ def save_sharded_model( return Path(checkpoint_path).mkdir(parents=True, exist_ok=True) - + if use_async and self.coordinator.is_master(): if id(model) not in self.pinned_state_dicts: self.pinned_state_dicts[id(model)] = {} pinned_state_dicts = self.pinned_state_dicts[id(model)] else: pinned_state_dicts = None - state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, pinned_state_dicts=pinned_state_dicts) + state_dict_shard = model.state_dict_shard( + max_shard_size=max_shard_size, only_rank_0=True, pinned_state_dicts=pinned_state_dicts + ) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint_path) @@ -249,7 +253,9 @@ def save_sharded_optimizer( 
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)] else: pinned_state_dicts = None - state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True, pinned_state_dicts=pinned_state_dicts) + state_dict_shard = optimizer.state_shard( + prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True, pinned_state_dicts=pinned_state_dicts + ) # Save shards of optimizer states. if use_async: @@ -259,7 +265,7 @@ def save_sharded_optimizer( index_file=index_file, base_filename=states_name, is_master=True, - state_preprocess=True + state_preprocess=True, ) self.async_writers.extend(writers) else: diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index a969b6ea9de5..125c29370ad2 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -26,10 +26,10 @@ from torch.utils.data import DataLoader from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils +from colossalai.checkpoint_io.utils import async_save_state_dict_shards, create_pinned_state_dict from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.logging import get_dist_logger -from colossalai.checkpoint_io.utils import create_pinned_state_dict, async_save_state_dict_shards from colossalai.utils.safetensors import load_flat from .dp_plugin_base import DPPluginBase @@ -52,7 +52,7 @@ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: boo def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path): assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!" 
if checkpoint.endswith(".safetensors"): - checkpoint = load_flat(checkpoint, seperator='-') + checkpoint = load_flat(checkpoint, seperator="-") else: checkpoint = utils.load_state_dict(checkpoint) fsdp_model = optimizer.unwrap_model() @@ -72,6 +72,7 @@ def save_unsharded_model( full_model_state = model.state_dict() if use_async: from colossalai.utils.safetensors import save + if id(model) not in self.pinned_state_dicts: self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state) for k, v in full_model_state.items(): @@ -92,8 +93,9 @@ def save_unsharded_optimizer( fsdp_model = optimizer.unwrap_model() full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True) if use_async: - from colossalai.utils.safetensors import save, _flatten_optim_state_dict - flatten_state_dict, metadata = _flatten_optim_state_dict(full_optimizer_state, seperator='-') + from colossalai.utils.safetensors import _flatten_optim_state_dict, save + + flatten_state_dict, metadata = _flatten_optim_state_dict(full_optimizer_state, seperator="-") if id(optimizer) not in self.pinned_state_dicts: self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict) for k, v in flatten_state_dict.items(): @@ -134,7 +136,9 @@ def save_sharded_model( pinned_state_dicts = self.pinned_state_dicts[id(model)] else: pinned_state_dicts = None - state_dict_shard = utils.shard_model_checkpoint(state_dict, max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts) + state_dict_shard = utils.shard_model_checkpoint( + state_dict, max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts + ) weights_name, save_index_file = utils.get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint_path) @@ -231,7 +235,9 @@ def save_sharded_optimizer( if self.coordinator.is_master(): # Preparing file paths and index file. - states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(prefix, use_safetensors=use_async) + states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames( + prefix, use_safetensors=use_async + ) index_file = CheckpointIndexFile(checkpoint) index_file.append_meta_data("param_groups", param_group_file) @@ -244,7 +250,9 @@ def save_sharded_optimizer( pinned_state_dicts = self.pinned_state_dicts[id(optimizer)] else: pinned_state_dicts = None - sharded_state = utils.shard_optimizer_checkpoint(fsdp_optim_state, max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts) + sharded_state = utils.shard_optimizer_checkpoint( + fsdp_optim_state, max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts + ) # Save shards of optimizer states. # In general cases, is_master is set to True to get the right behavior. 
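# Illustrative sketch (not part of the patch): because
# get_optimizer_base_filenames is called with use_safetensors=use_async,
# async shards get a .safetensors suffix while sync shards keep .bin, and the
# loaders dispatch on that suffix alone, as load_sharded_optimizer does
# below. The "-" seperator matches this revision of the series.
from pathlib import Path

from colossalai.checkpoint_io import utils
from colossalai.utils.safetensors import load_flat

def load_optimizer_shard(shard_file: str) -> dict:
    # Async checkpoints are flat safetensors; sync ones are pickled .bin files.
    if shard_file.endswith(".safetensors"):
        return load_flat(shard_file, seperator="-")
    return utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)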
if use_async: @@ -254,7 +262,7 @@ def save_sharded_optimizer( index_file=index_file, base_filename=states_name, is_master=True, - state_preprocess=True + state_preprocess=True, ) self.async_writers.extend(writers) else: @@ -298,7 +306,7 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, siz checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() for shard_file in checkpoint_files: if shard_file.endswith(".safetensors"): - state_dict_shard = load_flat(shard_file, seperator='-') + state_dict_shard = load_flat(shard_file, seperator="-") else: state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False) fsdp_optim_state.update(state_dict_shard) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 976475469658..f6bf1bb4a71d 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -8,6 +8,8 @@ import torch.nn as nn from torch.optim import Optimizer +from colossalai.utils.safetensors import load_flat + from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile from .utils import ( @@ -30,8 +32,6 @@ sharded_optimizer_loading_epilogue, ) -from colossalai.utils.safetensors import load_flat - __all__ = ["GeneralCheckpointIO"] @@ -51,7 +51,7 @@ def save_unsharded_model( if use_async: from colossalai.utils.safetensors import move_and_save - + if id(model) not in self.pinned_state_dicts: self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)]) @@ -135,7 +135,7 @@ def save_sharded_optimizer( base_filename=states_name, is_master=True, pinned_state_dict=pinned_state_dict, - state_preprocess=True + state_preprocess=True, ) self.pinned_state_dicts[id(optimizer)] = new_pinned_state_dict self.async_writers.extend(writers) @@ -175,12 +175,17 @@ def save_unsharded_optimizer( # TODO(FrankLeeeee): handle distributed tensors state_dict = optimizer.state_dict() if use_async: - from colossalai.utils.safetensors import move_and_save, _flatten_optim_state_dict + from colossalai.utils.safetensors import _flatten_optim_state_dict, move_and_save flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict) if id(optimizer) not in self.pinned_state_dicts: self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict) - writer = move_and_save(path=checkpoint, state_dict=flatten_state_dict, state_dict_pinned=self.pinned_state_dicts[id(optimizer)], metadata=metadata) + writer = move_and_save( + path=checkpoint, + state_dict=flatten_state_dict, + state_dict_pinned=self.pinned_state_dicts[id(optimizer)], + metadata=metadata, + ) self.async_writers.append(writer) else: save_state_dict(state_dict, checkpoint, use_safetensors=False) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 107846737474..af0147b98ba9 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -307,7 +307,7 @@ def async_save_state_dict_shards( checkpoint_file_path = os.path.join(checkpoint, shard_file) if state_preprocess: - state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator='-') + state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator="-") else: state_dict = shard @@ -391,7 +391,11 @@ def async_move_save_state_dict_shards( return total_size, returned_state_dict, writers -def 
shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[int,Dict[str, torch.Tensor]]] = None) -> Iterator[Tuple[OrderedDict, int]]:
+def shard_model_checkpoint(
+    state_dict: torch.Tensor,
+    max_shard_size: int = 1024,
+    pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None,
+) -> Iterator[Tuple[OrderedDict, int]]:
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
     given size.
@@ -413,7 +417,9 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024,
         yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
 
 
-def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, torch.Tensor]]=None) -> Iterator[Tuple[OrderedDict, int]]:
+def shard_optimizer_checkpoint(
+    state_dict: dict, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None
+) -> Iterator[Tuple[OrderedDict, int]]:
     """
     Splits an optimizer state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not
     exceed a given size.
diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py
index e8d0f1f34904..9cc8e7fa61e4 100644
--- a/colossalai/utils/safetensors.py
+++ b/colossalai/utils/safetensors.py
@@ -4,8 +4,8 @@
 from typing import Dict, List, Optional, Tuple
 
 import torch
-from safetensors.torch import _TYPES, load_file, safe_open
 import torch.distributed
+from safetensors.torch import _TYPES, load_file, safe_open
 
 try:
     from tensornvme.async_file_io import AsyncFileWriter
@@ -109,7 +109,7 @@ def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None
         try:
             idx = int(parts[1])
         except:
-            # exception for fsdp, parts[1] isn't param_id 
+            # exception for fsdp, parts[1] isn't param_id
             idx = parts[1]
         key = parts[2]
         if idx not in states:
@@ -180,7 +180,7 @@ def move_and_save(
     path: str,
     state_dict: Dict[str, torch.Tensor],
     state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
-    metadata: Optional[Dict[str, str]] = None
+    metadata: Optional[Dict[str, str]] = None,
 ) -> None:
     prepared_data, _, tensor_keys = prepare(state_dict, metadata)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 38315a523556..53ce34890d9d 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -903,7 +903,7 @@ def state_dict_shard(
         keep_vars: bool = False,
         max_shard_size: int = 1024,
         only_rank_0: bool = True,
-        pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None
+        pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Iterator[Tuple[OrderedDict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of
         dictionary shard is specified by ``max_shard_size``.
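# Illustrative sketch (not part of the patch): in the flat safetensors layout
# touched by the hunks above, tensors live under the flattened keys, while
# non-tensor state entries are pickled into byte tensors and their keys
# recorded as JSON in the "__metadata__" header, so that
# _unflatten_optim_state_dict can cast them back on load. The exact keys
# depend on the optimizer; "state-0-step" below is hypothetical.
import json

example_metadata = {
    "format": "pt",  # merged in unconditionally by prepare()
    "non_tensor_keys": json.dumps(["state-0-step"]),
}
assert json.loads(example_metadata["non_tensor_keys"]) == ["state-0-step"]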
@@ -946,7 +946,9 @@ def state_dict_shard( if pinned_state_dicts is not None: if (prefix + name) not in pinned_state_dicts: - pinned_state_dicts[prefix + name] = torch.empty_like(gathered_param, pin_memory=True, device="cpu") + pinned_state_dicts[prefix + name] = torch.empty_like( + gathered_param, pin_memory=True, device="cpu" + ) pinned_state_dicts[prefix + name].copy_(gathered_param) gathered_param = pinned_state_dicts[prefix + name] block, block_size = sharder.append_param(prefix + name, gathered_param) diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 62074adc4612..c403515a498f 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -809,7 +809,11 @@ def load_state_dict(self, state_dict: dict): self.optimizer_loading_epilogue() def state_shard( - self, prefix: str = "", max_shard_size: int = 1024, only_rank_0: bool = True, pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None + self, + prefix: str = "", + max_shard_size: int = 1024, + only_rank_0: bool = True, + pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None, ) -> Iterator[Tuple[OrderedDict, int]]: """Returns dictionaries containing shards of optimizer states one by one. The max size of each dictionary shard is specified by ``max_shard_size``. diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 5a47f9a3680f..bb4f44a2a14b 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -36,7 +36,9 @@ @parameterize("tp_size", [1, 2]) @parameterize("zero_size", [2]) @parameterize("use_async", [False, True]) -def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int, use_async: bool): +def exam_state_dict_with_origin( + placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int, use_async: bool +): from transformers import BertForSequenceClassification (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) @@ -71,7 +73,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b "", (model_size / 3), use_safetensors=use_safetensors, - use_async=use_async + use_async=use_async, ) booster.checkpoint_io._sync_d2h() booster.checkpoint_io._sync_io() @@ -88,7 +90,9 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b @parameterize("tp_size", [1, 2]) @parameterize("zero_size", [2]) @parameterize("use_async", [False, True]) -def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int, use_async: bool): +def exam_state_dict( + placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int, use_async: bool +): (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() enable_flash_attention = True if tp_size > 1 else False @@ -134,15 +138,11 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha model_ckpt_path = f"{model_ckpt_path}.safetensors" optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors" - booster.save_model( - model, - model_ckpt_path, - shard=shard, - size_per_shard=size_per_shard, - use_async=use_async - ) + booster.save_model(model, model_ckpt_path, 
shard=shard, size_per_shard=size_per_shard, use_async=use_async) - booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async) + booster.save_optimizer( + optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async + ) booster.checkpoint_io._sync_d2h() booster.checkpoint_io._sync_io() dist.barrier() @@ -168,7 +168,7 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha loss = criterion(output[output_key]) booster.backward(loss, new_optimizer) new_optimizer.step() - + with shared_tempdir() as tempdir: model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index 25dfef0e25a8..327be0bb7d6f 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -93,7 +93,9 @@ def test_sharded_model_checkpoint(use_safetensors: bool, use_async: bool): # save the model and optimizer ckpt_io = GeneralCheckpointIO() - ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors, use_async=use_async) + ckpt_io.save_model( + model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors, use_async=use_async + ) ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name, shard=False) ckpt_io._sync_d2h() @@ -157,7 +159,6 @@ def test_sharded_optimizer_checkpoint(use_async: bool): loss.backward() new_optimizer.step() - # create temp directories for checkpoint model_ckpt_dir = tempfile.TemporaryDirectory() optimizer_ckpt_dir = tempfile.TemporaryDirectory() diff --git a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py index e35317425599..b7d746888ae3 100644 --- a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py @@ -46,7 +46,9 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bo optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors" booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async) - booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async) + booster.save_optimizer( + optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async + ) booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path) booster.checkpoint_io._sync_d2h() booster.checkpoint_io._sync_io() diff --git a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py index b10a0692a39a..b362875d054d 100644 --- a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py @@ -12,7 +12,7 @@ from colossalai.booster.plugin import TorchFSDPPlugin from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from colossalai.testing import rerun_if_address_is_in_use, spawn, parameterize +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn def compare_nested_dict(dict1, dict2): @@ -42,6 +42,7 @@ def compare_nested_dict(dict1, dict2): return False return True + @parameterize("use_async", [False, True]) def check_torch_fsdp_ckpt(use_async: bool): model = resnet18() From 
0073f9768dab4de2e3f016a8e7dd3a3623ddce6a Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 12 Dec 2024 17:30:48 +0800 Subject: [PATCH 12/18] fix --- colossalai/utils/safetensors.py | 1 - tests/test_checkpoint_io/test_gemini_checkpoint_io.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py index 9cc8e7fa61e4..60ace5f57442 100644 --- a/colossalai/utils/safetensors.py +++ b/colossalai/utils/safetensors.py @@ -4,7 +4,6 @@ from typing import Dict, List, Optional, Tuple import torch -import torch.distributed from safetensors.torch import _TYPES, load_file, safe_open try: diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index bb4f44a2a14b..a6d65cae5953 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -169,9 +169,9 @@ def exam_state_dict( booster.backward(loss, new_optimizer) new_optimizer.step() - with shared_tempdir() as tempdir: - model_ckpt_path = f"{tempdir}/model" - optimizer_ckpt_path = f"{tempdir}/optimizer" + with shared_tempdir() as new_tempdir: + model_ckpt_path = f"{new_tempdir}/model" + optimizer_ckpt_path = f"{new_tempdir}/optimizer" if not shard and use_async: model_ckpt_path = f"{model_ckpt_path}.safetensors" From 3bafe2a434294d74013b361d80866e328c850ab5 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 12 Dec 2024 18:04:10 +0800 Subject: [PATCH 13/18] fix --- colossalai/booster/plugin/gemini_plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 7fb4d03a4fe8..441670a0aaea 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -180,7 +180,7 @@ def save_sharded_model( checkpoint=checkpoint_path, index_file=index_file, base_filename=weights_name, - is_master=True, + is_master=is_master, ) self.async_writers.extend(writers) else: @@ -264,7 +264,7 @@ def save_sharded_optimizer( checkpoint=checkpoint, index_file=index_file, base_filename=states_name, - is_master=True, + is_master=self.coordinator.is_master(), state_preprocess=True, ) self.async_writers.extend(writers) From d0e2baacb5ef8c9ed6506dce8e3c1a809f93362e Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 17 Dec 2024 15:28:59 +0800 Subject: [PATCH 14/18] support general param_id --- .../booster/plugin/hybrid_parallel_plugin.py | 2 +- .../plugin/moe_hybrid_parallel_plugin.py | 2 +- .../booster/plugin/torch_fsdp_plugin.py | 176 +++++++++++++++--- .../hybrid_parallel_checkpoint_io.py | 11 +- colossalai/checkpoint_io/moe_checkpoint.py | 3 +- colossalai/checkpoint_io/utils.py | 2 +- colossalai/utils/safetensors.py | 4 +- .../test_torch_fsdp_checkpoint_io.py | 4 +- 8 files changed, 162 insertions(+), 42 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 79c9379ccf1d..bc9425a0b0cd 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1488,7 +1488,7 @@ def seed_worker(worker_id): ) def get_checkpoint_io(self) -> CheckpointIO: - return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + return HybridParallelCheckpointIO(self.dp_group, self.pp_group, 
self.tp_group, self.sp_group, self.zero_stage) def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]: assert ( diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 96531a04fd69..12ee54f34106 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -404,7 +404,7 @@ def __init__( def get_checkpoint_io(self) -> MoECheckpointIO: return MoECheckpointIO( - self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage + self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.ep_group, self.moe_dp_group, self.zero_stage ) def configure( diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index 125c29370ad2..2bcd61db56db 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -1,6 +1,5 @@ import os -from pathlib import Path -from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple +from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Any import torch import torch.nn as nn @@ -52,10 +51,40 @@ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: boo def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path): assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!" if checkpoint.endswith(".safetensors"): - checkpoint = load_flat(checkpoint, seperator="-") + checkpoint = load_flat(checkpoint, seperator=".") else: checkpoint = utils.load_state_dict(checkpoint) + fsdp_model = optimizer.unwrap_model() + full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=False) + start_index = 0 + id2name = {} + def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]: + nonlocal start_index + start_num = len(id2name) + id2name.update( + { + i: p + for i, p in enumerate(group["params"], start_index) + if i not in id2name + } + ) + end_num = len(id2name) + start_index += end_num - start_num + + for g in full_optimizer_state["param_groups"]: + get_index_mapping(g) + + new_state = {} + for key, value in checkpoint["state"].items(): + new_state[id2name[int(key)]] = value + checkpoint["state"] = new_state + for g in checkpoint["param_groups"]: + new_group = [] + for param_id in g["params"]: + new_group.append(id2name[param_id]) + g["params"] = new_group + sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model) optimizer.load_state_dict(sharded_osd) @@ -70,18 +99,19 @@ def save_unsharded_model( cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): full_model_state = model.state_dict() - if use_async: - from colossalai.utils.safetensors import save - - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state) - for k, v in full_model_state.items(): - self.pinned_state_dicts[id(model)][k].copy_(v) - full_model_state[k] = self.pinned_state_dicts[id(model)][k] - writer = save(checkpoint, full_model_state) - self.async_writers.append(writer) - else: - utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors) + if self.coordinator.is_master(): + if use_async: + from colossalai.utils.safetensors import save + + if 
id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state) + for k, v in full_model_state.items(): + self.pinned_state_dicts[id(model)][k].copy_(v) + full_model_state[k] = self.pinned_state_dicts[id(model)][k] + writer = save(checkpoint, full_model_state) + self.async_writers.append(writer) + else: + utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors) def save_unsharded_optimizer( self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False @@ -91,20 +121,48 @@ def save_unsharded_optimizer( """ assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!" fsdp_model = optimizer.unwrap_model() + full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True) - if use_async: - from colossalai.utils.safetensors import _flatten_optim_state_dict, save - - flatten_state_dict, metadata = _flatten_optim_state_dict(full_optimizer_state, seperator="-") - if id(optimizer) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict) - for k, v in flatten_state_dict.items(): - self.pinned_state_dicts[id(optimizer)][k].copy_(v) - flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k] - writer = save(checkpoint, state_dict=flatten_state_dict, metadata=metadata) - self.async_writers.append(writer) - else: - utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False) + + if self.coordinator.is_master(): + + # Save order indices instead of Tensors + name2id: Dict[str, int] = {} + start_index = 0 + + def pack_group(group: Dict[str, Any]) -> Dict[str, Any]: + nonlocal start_index + packed = {k: v for k, v in group.items() if k != "params"} + name2id.update( + { + p: i + for i, p in enumerate(group["params"], start_index) + if p not in name2id + } + ) + packed["params"] = [name2id[p] for p in group["params"]] + start_index += len(packed["params"]) + return packed + + param_groups = [pack_group(g) for g in full_optimizer_state["param_groups"]] + full_optimizer_state["param_groups"] = param_groups + new_state = {} + for key, value in full_optimizer_state["state"].items(): + new_state[name2id[key]] = value + full_optimizer_state["state"] = new_state + + if use_async: + from colossalai.utils.safetensors import _flatten_optim_state_dict, save + flatten_state_dict, metadata = _flatten_optim_state_dict(full_optimizer_state, seperator=".") + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict) + for k, v in flatten_state_dict.items(): + self.pinned_state_dicts[id(optimizer)][k].copy_(v) + flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k] + writer = save(checkpoint, state_dict=flatten_state_dict, metadata=metadata) + self.async_writers.append(writer) + else: + utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False) def save_sharded_model( self, @@ -150,7 +208,7 @@ def save_sharded_model( checkpoint=checkpoint_path, index_file=index_file, base_filename=weights_name, - is_master=True, + is_master=self.coordinator.is_master(), ) self.async_writers.extend(writers) else: @@ -234,6 +292,32 @@ def save_sharded_optimizer( ) if self.coordinator.is_master(): + + # Save order indices instead of Tensors + name2id: Dict[str, int] = {} + start_index = 0 + + def 
pack_group(group: Dict[str, Any]) -> Dict[str, Any]: + nonlocal start_index + packed = {k: v for k, v in group.items() if k != "params"} + name2id.update( + { + p: i + for i, p in enumerate(group["params"], start_index) + if p not in name2id + } + ) + packed["params"] = [name2id[p] for p in group["params"]] + start_index += len(packed["params"]) + return packed + + param_groups = [pack_group(g) for g in fsdp_optim_state["param_groups"]] + fsdp_optim_state["param_groups"] = param_groups + new_state = {} + for key, value in fsdp_optim_state["state"].items(): + new_state[name2id[key]] = value + fsdp_optim_state["state"] = new_state + # Preparing file paths and index file. states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames( prefix, use_safetensors=use_async @@ -261,7 +345,7 @@ def save_sharded_optimizer( checkpoint=checkpoint, index_file=index_file, base_filename=states_name, - is_master=True, + is_master=self.coordinator.is_master(), state_preprocess=True, ) self.async_writers.extend(writers) @@ -306,13 +390,43 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, siz checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() for shard_file in checkpoint_files: if shard_file.endswith(".safetensors"): - state_dict_shard = load_flat(shard_file, seperator="-") + state_dict_shard = load_flat(shard_file, seperator=".") else: state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False) fsdp_optim_state.update(state_dict_shard) fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups) + fsdp_model = optimizer.unwrap_model() + full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model.unwrap(), optim=optimizer, rank0_only=False) + start_index = 0 + id2name = {} + def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]: + nonlocal start_index + start_num = len(id2name) + id2name.update( + { + i: p + for i, p in enumerate(group["params"], start_index) + if i not in id2name + } + ) + end_num = len(id2name) + start_index += end_num - start_num + + for g in full_optimizer_state["param_groups"]: + get_index_mapping(g) + + new_state = {} + for key, value in fsdp_optim_dict["state"].items(): + new_state[id2name[int(key)]] = value + fsdp_optim_dict["state"] = new_state + for g in fsdp_optim_dict["param_groups"]: + new_group = [] + for param_id in g["params"]: + new_group.append(id2name[param_id]) + g["params"] = new_group + with FSDP.state_dict_type(optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT): fsdp_state = FSDP.optim_state_dict_to_load( model=optimizer.unwrap_model().unwrap(), optim=optimizer, optim_state_dict=fsdp_optim_dict diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index d553294d0838..e7c203910e14 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -70,6 +70,7 @@ def __init__( dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: ProcessGroup, + sp_group: ProcessGroup, zero_stage: int, verbose: bool = True, ) -> None: @@ -77,9 +78,11 @@ def __init__( self.global_dp_group = dp_group self.pp_group = pp_group self.tp_group = tp_group + self.sp_group = sp_group self.dp_rank = dist.get_rank(self.global_dp_group) self.tp_rank = dist.get_rank(self.tp_group) self.pp_rank = dist.get_rank(self.pp_group) + self.sp_rank = dist.get_rank(self.sp_group) self.global_dp_size = dist.get_world_size(dp_group) 
         self.pp_size = dist.get_world_size(pp_group)
         self.tp_size = dist.get_world_size(tp_group)
@@ -490,7 +493,7 @@ def save_sharded_optimizer(
         # Then collect the sharded states along dp_group(if using zero)/tp_group.
         # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
-        control_saving = self.dp_rank == 0 and self.tp_rank == 0
+        control_saving = self.dp_rank == 0 and self.tp_rank == 0 and self.sp_rank == 0
 
         if use_async and control_saving:
             if id(optimizer) not in self.pinned_state_dicts:
@@ -560,8 +563,10 @@ def save_sharded_optimizer(
             Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
 
             # Manage filenames of sharded weights and index file for each pipeline stage.
-            states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
-            states_name = states_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
+            if not use_async:
+                states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
+            else:
+                states_name = states_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
             save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
             save_index_file = os.path.join("tmp_index_files", save_index_file)
diff --git a/colossalai/checkpoint_io/moe_checkpoint.py b/colossalai/checkpoint_io/moe_checkpoint.py
index 3b07856ca06c..244f5bc0b644 100644
--- a/colossalai/checkpoint_io/moe_checkpoint.py
+++ b/colossalai/checkpoint_io/moe_checkpoint.py
@@ -44,12 +44,13 @@ def __init__(
         global_dp_group: ProcessGroup,
         pp_group: ProcessGroup,
         tp_group: ProcessGroup,
+        sp_group: ProcessGroup,
         ep_group: ProcessGroup,
         moe_dp_group: ProcessGroup,
         zero_stage: int,
         verbose: bool = True,
     ) -> None:
-        super().__init__(global_dp_group, pp_group, tp_group, zero_stage, verbose)
+        super().__init__(global_dp_group, pp_group, tp_group, sp_group, zero_stage, verbose)
         self.global_dp_group = global_dp_group
         self.global_dp_rank = dist.get_rank(global_dp_group)
         self.global_dp_size = dist.get_world_size(global_dp_group)
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index af0147b98ba9..11fa6d586e07 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -307,7 +307,7 @@ def async_save_state_dict_shards(
         checkpoint_file_path = os.path.join(checkpoint, shard_file)
 
         if state_preprocess:
-            state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator="-")
+            state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator=".")
         else:
             state_dict = shard
diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py
index 60ace5f57442..b412953dd76e 100644
--- a/colossalai/utils/safetensors.py
+++ b/colossalai/utils/safetensors.py
@@ -59,7 +59,7 @@ def _cast_to_object(tensor: torch.Tensor):
     return _tensor_to_object(tensor, tensor.numel() * tensor.element_size())
 
 
-def _flatten_optim_state_dict(state_dict: dict, seperator: str = "-") -> Tuple[dict, Optional[dict]]:
+def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[dict, Optional[dict]]:
     flat_dict = {}
     non_tensor_keys = []
     if "state" in state_dict:
@@ -196,7 +196,7 @@ def move_and_save(
     return f_writer
 
 
-def load_flat(checkpoint_path, seperator: str = "-"):
+def load_flat(checkpoint_path, seperator: str = "."):
     with safe_open(checkpoint_path, framework="pt") as f:
         metadata = f.metadata()
         state_dict_load = load_file(checkpoint_path)
diff --git a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
index b362875d054d..25d901538064 100644
--- a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
@@ -114,8 +114,8 @@ def run_model():
 
     run_model()
 
-    booster.save_model(fsdp_model, model_ckpt_path, shard=True, use_async=False)
-    booster.save_optimizer(optimizer, optim_ckpt_path, shard=True, use_async=True)
+    booster.save_model(fsdp_model, model_ckpt_path, shard=True, use_async=use_async)
+    booster.save_optimizer(optimizer, optim_ckpt_path, shard=True, use_async=use_async)
 
     booster.checkpoint_io._sync_d2h()
     booster.checkpoint_io._sync_io()

From 4e71953ad5310ec0b5cedcce7b847a5ca20b8317 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 17 Dec 2024 07:41:27 +0000
Subject: [PATCH 15/18] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../plugin/moe_hybrid_parallel_plugin.py      |  8 ++-
 .../booster/plugin/torch_fsdp_plugin.py       | 53 ++++++-------------
 2 files changed, 24 insertions(+), 37 deletions(-)

diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 12ee54f34106..6937b8d74ab9 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -404,7 +404,13 @@ def __init__(
 
     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
-            self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.ep_group, self.moe_dp_group, self.zero_stage
+            self.dp_group,
+            self.pp_group,
+            self.tp_group,
+            self.sp_group,
+            self.ep_group,
+            self.moe_dp_group,
+            self.zero_stage,
         )
 
     def configure(
diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py
index 2bcd61db56db..08364cb5a0cb 100644
--- a/colossalai/booster/plugin/torch_fsdp_plugin.py
+++ b/colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -1,5 +1,5 @@
 import os
-from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Any
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -59,22 +59,17 @@ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path
         full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=False)
         start_index = 0
         id2name = {}
+
         def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]:
             nonlocal start_index
             start_num = len(id2name)
-            id2name.update(
-                {
-                    i: p
-                    for i, p in enumerate(group["params"], start_index)
-                    if i not in id2name
-                }
-            )
+            id2name.update({i: p for i, p in enumerate(group["params"], start_index) if i not in id2name})
             end_num = len(id2name)
             start_index += end_num - start_num
-
+
         for g in full_optimizer_state["param_groups"]:
             get_index_mapping(g)
-
+
         new_state = {}
         for key, value in checkpoint["state"].items():
             new_state[id2name[int(key)]] = value
@@ -111,7 +106,9 @@ def save_unsharded_model(
                 writer = save(checkpoint, full_model_state)
                 self.async_writers.append(writer)
             else:
-                utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
+                utils.save_state_dict(
+                    full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors
+                )
 
     def save_unsharded_optimizer(
         self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
@@ -133,17 +130,11 @@ def save_unsharded_optimizer(
         def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
             nonlocal start_index
             packed = {k: v for k, v in group.items() if k != "params"}
-            name2id.update(
-                {
-                    p: i
-                    for i, p in enumerate(group["params"], start_index)
-                    if p not in name2id
-                }
-            )
+            name2id.update({p: i for i, p in enumerate(group["params"], start_index) if p not in name2id})
             packed["params"] = [name2id[p] for p in group["params"]]
             start_index += len(packed["params"])
             return packed
-
+
         param_groups = [pack_group(g) for g in full_optimizer_state["param_groups"]]
         full_optimizer_state["param_groups"] = param_groups
         new_state = {}
@@ -153,6 +144,7 @@ def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
 
         if use_async:
             from colossalai.utils.safetensors import _flatten_optim_state_dict, save
+
             flatten_state_dict, metadata = _flatten_optim_state_dict(full_optimizer_state, seperator=".")
             if id(optimizer) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
@@ -300,17 +292,11 @@ def save_sharded_optimizer(
         def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
             nonlocal start_index
             packed = {k: v for k, v in group.items() if k != "params"}
-            name2id.update(
-                {
-                    p: i
-                    for i, p in enumerate(group["params"], start_index)
-                    if p not in name2id
-                }
-            )
+            name2id.update({p: i for i, p in enumerate(group["params"], start_index) if p not in name2id})
             packed["params"] = [name2id[p] for p in group["params"]]
             start_index += len(packed["params"])
             return packed
-
+
         param_groups = [pack_group(g) for g in fsdp_optim_state["param_groups"]]
         fsdp_optim_state["param_groups"] = param_groups
         new_state = {}
@@ -401,22 +387,17 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, siz
         full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model.unwrap(), optim=optimizer, rank0_only=False)
         start_index = 0
         id2name = {}
+
        def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]:
             nonlocal start_index
             start_num = len(id2name)
-            id2name.update(
-                {
-                    i: p
-                    for i, p in enumerate(group["params"], start_index)
-                    if i not in id2name
-                }
-            )
+            id2name.update({i: p for i, p in enumerate(group["params"], start_index) if i not in id2name})
             end_num = len(id2name)
             start_index += end_num - start_num
-
+
         for g in full_optimizer_state["param_groups"]:
             get_index_mapping(g)
-
+
         new_state = {}
         for key, value in fsdp_optim_dict["state"].items():
             new_state[id2name[int(key)]] = value

From 83dc8befeca02091b65f1b0dc82db8cece051c89 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Tue, 17 Dec 2024 16:46:05 +0800
Subject: [PATCH 16/18] fix

---
 colossalai/booster/plugin/torch_fsdp_plugin.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py
index 08364cb5a0cb..1d792757b9de 100644
--- a/colossalai/booster/plugin/torch_fsdp_plugin.py
+++ b/colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -1,4 +1,5 @@
 import os
+from pathlib import Path
 from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
 
 import torch

From 115b74422ec2e59e50917cd98d274d0accac65a3 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Wed, 18 Dec 2024 13:56:11 +0800
Subject: [PATCH 17/18] fix

---
 .../checkpoint_io/hybrid_parallel_checkpoint_io.py |  2 ++
 colossalai/checkpoint_io/utils.py                  |  6 ++++--
 colossalai/utils/safetensors.py                    |  2 ++
 colossalai/zero/gemini/gemini_ddp.py               | 12 ++++++------
 colossalai/zero/gemini/gemini_optimizer.py         |  8 ++++++--
 .../test_torch_ddp_checkpoint_io.py                |  2 +-
 6 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index e7c203910e14..7123d7c8c122 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -1017,6 +1017,8 @@ def gather_from_sharded_optimizer_state(
             state_ = state if inplace else copy.deepcopy(state)
 
             for k, v in state_.items():
+                if v is None:
+                    continue
                 if isinstance(v, torch.Tensor) and k != "step":
                     # First gather Zero shards.
                     if use_zero:
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 11fa6d586e07..71422f4c2dcc 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -405,7 +405,8 @@ def shard_model_checkpoint(
     for key, weight in state_dict.items():
         if not is_distributed_tensor(weight):
             if pinned_state_dicts is not None:
-                pinned_state_dicts[key] = torch.empty_like(weight, pin_memory=True, device="cpu")
+                if key not in pinned_state_dicts:
+                    pinned_state_dicts[key] = torch.empty_like(weight, pin_memory=True, device="cpu")
                 pinned_state_dicts[key].copy_(weight)
                 weight = pinned_state_dicts[key]
             block, block_size = state_dict_sharder.append_param(key, weight)
@@ -434,7 +435,8 @@ def shard_optimizer_checkpoint(
             if param_id not in pinned_state_dicts:
                 pinned_state_dicts[param_id] = {}
             for k, v in state.items():
-                pinned_state_dicts[param_id][k] = torch.empty_like(v, pin_memory=True, device="cpu")
+                if k not in pinned_state_dicts[param_id]:
+                    pinned_state_dicts[param_id][k] = torch.empty_like(v, pin_memory=True, device="cpu")
                 pinned_state_dicts[param_id][k].copy_(v)
                 state[k] = pinned_state_dicts[param_id][k]
 
diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py
index b412953dd76e..04bd414f171a 100644
--- a/colossalai/utils/safetensors.py
+++ b/colossalai/utils/safetensors.py
@@ -71,6 +71,8 @@ def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[d
 
         for idx, d in states.items():
             for k, v in d.items():
+                if v is None:
+                    continue
                 nested_key = f"state{seperator}{idx}{seperator}{k}"
                 if not isinstance(v, torch.Tensor):
                     non_tensor_keys.append(nested_key)
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 53ce34890d9d..9e89e88272e0 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -949,8 +949,8 @@ def state_dict_shard(
                         pinned_state_dicts[prefix + name] = torch.empty_like(
                             gathered_param, pin_memory=True, device="cpu"
                         )
-                        pinned_state_dicts[prefix + name].copy_(gathered_param)
-                        gathered_param = pinned_state_dicts[prefix + name]
+                    pinned_state_dicts[prefix + name].copy_(gathered_param)
+                    gathered_param = pinned_state_dicts[prefix + name]
                 block, block_size = sharder.append_param(prefix + name, gathered_param)
                 if block is not None:
                     yield block, block_size
@@ -965,8 +965,8 @@ def state_dict_shard(
                 if pinned_state_dicts is not None:
                     if (prefix + name) not in pinned_state_dicts:
                         pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu")
-                        pinned_state_dicts[prefix + name].copy_(buffer)
-                        buffer = pinned_state_dicts[prefix + name]
+                    pinned_state_dicts[prefix + name].copy_(buffer)
+                    buffer = pinned_state_dicts[prefix + name]
                 block, block_size = sharder.append_param(prefix + name, buffer)
                 if block is not None:
                     yield block, block_size
@@ -980,8 +980,8 @@ def state_dict_shard(
             if pinned_state_dicts is not None:
                 if extra_state_key not in pinned_state_dicts:
                     pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu")
-                    pinned_state_dicts[extra_state_key].copy_(extra_state)
-                    extra_state = pinned_state_dicts[extra_state_key]
+                pinned_state_dicts[extra_state_key].copy_(extra_state)
+                extra_state = pinned_state_dicts[extra_state_key]
             block, block_size = sharder.append_param(extra_state_key, extra_state)
             if block is not None:
                 yield block, block_size
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index c403515a498f..def96b19b357 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -834,9 +834,13 @@ def state_shard(
             state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
 
             if pinned_state_dicts is not None:
-                pinned_state_dicts[param_id] = {}
+                if param_id not in pinned_state_dicts:
+                    pinned_state_dicts[param_id] = {}
                 for k, v in state.items():
-                    pinned_state_dicts[param_id][k] = torch.empty_like(v, pin_memory=True, device="cpu")
+                    if v is None:
+                        continue
+                    if k not in pinned_state_dicts[param_id]:
+                        pinned_state_dicts[param_id][k] = torch.empty_like(v, pin_memory=True, device="cpu")
                     pinned_state_dicts[param_id][k].copy_(v)
                     state[k] = pinned_state_dicts[param_id][k]
             block, block_size = sharder.append_optim_state(param_id, state)
diff --git a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py
index b7d746888ae3..b90ea0960c8d 100644
--- a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py
@@ -20,7 +20,7 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bo
     booster = Booster(plugin=plugin)
     model = resnet18()
     criterion = lambda x: x.mean()
-    optimizer = SGD((model.parameters()), lr=0.001)
+    optimizer = SGD((model.parameters()), lr=0.001, momentum=0.5)
     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
 
     model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)

From 6efcde54f5ee5d778ac663d29b56809c4239f6d5 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Fri, 20 Dec 2024 10:39:59 +0800
Subject: [PATCH 18/18] fix

---
 colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 2 +-
 colossalai/checkpoint_io/moe_checkpoint.py                | 4 ++--
 colossalai/utils/safetensors.py                           | 6 +-----
 3 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 7123d7c8c122..0a2e598ca619 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -244,7 +244,7 @@ def save_sharded_model(
 
         # Then collect the sharded parameters & buffers along tp_group.
         # Only devices with tp_rank == 0 are responsible for model saving.
-        control_saving = self.tp_rank == 0
+        control_saving = self.tp_rank == 0 and self.sp_rank == 0
 
         if control_saving and use_async:
             if id(model) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(model)] = {}
diff --git a/colossalai/checkpoint_io/moe_checkpoint.py b/colossalai/checkpoint_io/moe_checkpoint.py
index 244f5bc0b644..f6aefd33a9f5 100644
--- a/colossalai/checkpoint_io/moe_checkpoint.py
+++ b/colossalai/checkpoint_io/moe_checkpoint.py
@@ -159,7 +159,7 @@ def save_sharded_model(
         state_dict_shard = MoECheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         index_file = CheckpointIndexFile(checkpoint)
-        control_saving = self.tp_rank == 0
+        control_saving = self.tp_rank == 0 and self.sp_rank == 0
 
         if self.pp_size == 1 and self.ep_size == 1:
             # When pipeline is not used, save the model shards as in general checkpointIO
@@ -416,7 +416,7 @@ def save_sharded_optimizer(
         # e.g. dp_size = 4, moe_dp_size = 2, ep_size = 2 and use gather
        # rank 0 saves moe & non-moe params; rank 1 only saves moe params
        # rank 3 & 4 save nothing
-        control_saving = self.tp_rank == 0 and self.moe_dp_rank == 0
+        control_saving = self.tp_rank == 0 and self.moe_dp_rank == 0 and self.sp_rank == 0
 
         if self.pp_size == 1 and self.ep_size == 1:
             # When pipeline is not used, save the optimizer shards as in general checkpointIO
diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py
index 04bd414f171a..8ce6d7335879 100644
--- a/colossalai/utils/safetensors.py
+++ b/colossalai/utils/safetensors.py
@@ -107,11 +107,7 @@ def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None
     for k, v in flat_dict.items():
         parts = k.split(seperator)
         assert len(parts) == 3 and parts[0] == "state"
-        try:
-            idx = int(parts[1])
-        except:
-            # exception for fsdp, part[1] isn't param_id
-            idx = parts[1]
+        idx = int(parts[1])
         key = parts[2]
         if idx not in states:
             states[idx] = {}
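
A recurring pattern across patches 14-18 above is packing FSDP's name-keyed optimizer state into integer-keyed state before flattening it for safetensors. FSDP.full_optim_state_dict keys per-parameter state by fully qualified parameter names such as "layer1.0.weight"; those names contain ".", which would collide with the "." separator now used by _flatten_optim_state_dict. Packing names to contiguous integer ids first keeps every flattened key in the form "state.{id}.{field}". The sketch below shows this in isolation; the helper names are illustrative, not ColossalAI APIs, and param_groups metadata is handled separately in the real code.

from typing import Any, Dict

import torch


def pack_param_names_to_ids(optim_state: Dict[str, Any]) -> Dict[str, Any]:
    # Mirror of pack_group() above: assign each fully qualified parameter
    # name (e.g. "layer1.0.weight") a contiguous integer id, group by group.
    name2id: Dict[str, int] = {}
    start_index = 0
    for group in optim_state["param_groups"]:
        for i, name in enumerate(group["params"], start_index):
            name2id.setdefault(name, i)
        group["params"] = [name2id[name] for name in group["params"]]
        start_index += len(group["params"])
    # Re-key the per-parameter states from names to integer ids.
    optim_state["state"] = {name2id[name]: s for name, s in optim_state["state"].items()}
    return optim_state


def flatten_optim_state(optim_state: Dict[str, Any], seperator: str = ".") -> Dict[str, torch.Tensor]:
    # Simplified analogue of _flatten_optim_state_dict: one flat tensor dict
    # keyed "state.{id}.{field}", skipping None entries as patch 17 does.
    flat = {}
    for idx, per_param in optim_state["state"].items():
        for field, value in per_param.items():
            if value is None:
                continue
            flat[f"state{seperator}{idx}{seperator}{field}"] = value
    return flat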
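
On the load side, patch 18 removes the FSDP-specific string-id fallback from _unflatten_optim_state_dict, which is only safe because the packing above guarantees integer ids. A simplified inverse of the flattening sketch, again with illustrative names:

from collections import defaultdict
from typing import Any, Dict

import torch


def unflatten_optim_state(flat: Dict[str, torch.Tensor], seperator: str = ".") -> Dict[str, Any]:
    # Every key must now be "state.{int_id}.{field}": exactly three parts
    # with an integer id, matching the assertion kept in patch 18.
    states: Dict[int, Dict[str, torch.Tensor]] = defaultdict(dict)
    for key, value in flat.items():
        parts = key.split(seperator)
        assert len(parts) == 3 and parts[0] == "state"
        states[int(parts[1])][parts[2]] = value
    return {"state": dict(states)}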
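
The other fix these patches converge on is how pinned host buffers are cached for async saves. Several call sites either re-allocated the pinned buffer on every save, or, due to an indentation slip, only copied into it on the first save (the gemini_ddp.py hunks in patch 17). The intended pattern, shown here as a minimal sketch with an illustrative module-level cache standing in for the per-model pinned_state_dicts kept on the CheckpointIO objects, is: allocate once per key, copy on every save.

from typing import Dict

import torch

# Illustrative stand-in for the pinned_state_dicts caches above.
_pinned_cache: Dict[str, torch.Tensor] = {}


def stage_for_async_save(key: str, tensor: torch.Tensor) -> torch.Tensor:
    # Allocate the pinned CPU buffer only the first time this key is saved...
    if key not in _pinned_cache:
        _pinned_cache[key] = torch.empty_like(tensor, pin_memory=True, device="cpu")
    # ...but refresh its contents on every save, so the second and later
    # checkpoints do not silently reuse the first checkpoint's values.
    _pinned_cache[key].copy_(tensor)
    return _pinned_cache[key]

Pinned memory is what keeps the device-to-host copy asynchronous, so the background writers flushed by _sync_d2h()/_sync_io() in the FSDP test above can overlap checkpoint I/O with training.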