From 98d0feb6dc383a3ce481f0699fa5061d7fda7c14 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 20 Nov 2024 13:21:02 +0800 Subject: [PATCH] fix --- .../checkpoint_io/hybrid_parallel_checkpoint_io.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 0cc5c4a7dcf1..1347e4da3a5b 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -720,16 +720,7 @@ def save_unsharded_model( # When pipeline is not used, let master rank directly save the collected state_dict. if self.tp_rank == 0: if use_async: - from tensornvme.async_file_io import AsyncFileWriter - - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) - - f_writer = AsyncFileWriter( - fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread" - ) - move_and_save(f_writer, state_dict=state_dict, state_dict_pinned=self.pinned_state_dicts[id(model)]) - self.async_writers.append(f_writer) + super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async) else: save_state_dict(state_dict, checkpoint, use_safetensors) else: