Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
flybird11111 committed Nov 20, 2024
1 parent 8df96c2 commit a5eb3ea
Showing 1 changed file with 1 addition and 10 deletions.
11 changes: 1 addition & 10 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,16 +720,7 @@ def save_unsharded_model(
# When pipeline is not used, let master rank directly save the collected state_dict.
if self.tp_rank == 0:
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)

f_writer = AsyncFileWriter(
fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
)
move_and_save(f_writer, state_dict=state_dict, state_dict_pinned=self.pinned_state_dicts[id(model)])
self.async_writers.append(f_writer)
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
else:
save_state_dict(state_dict, checkpoint, use_safetensors)
else:
Expand Down

0 comments on commit a5eb3ea

Please sign in to comment.