From 5dc598d74336d0bd26c7077efca935a0226ffc09 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 20 Nov 2024 11:57:54 +0800 Subject: [PATCH] fix checkpoint IO: pass state_dict_pinned by keyword to move_and_save; add per-stage shard names for .safetensors optimizer states --- colossalai/checkpoint_io/general_checkpoint_io.py | 2 +- colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 8f2f9f60c599..495cd4be3e47 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -60,7 +60,7 @@ def save_unsharded_model( if id(model) not in self.pinned_state_dicts: self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) self.async_writers.append(writer) - move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)]) + move_and_save(writer, state_dict, state_dict_pinned=self.pinned_state_dicts[id(model)]) else: # save the checkpoint save_state_dict(state_dict, checkpoint, use_safetensors) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 438cdb4ac39d..0cc5c4a7dcf1 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -530,6 +530,7 @@ def save_sharded_optimizer( # Manage filenames of sharded weights and index file for each pipeline stage. states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin") + states_name = states_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors") save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") save_index_file = os.path.join("tmp_index_files", save_index_file)