Commit

fix
flybird11111 committed Nov 20, 2024
1 parent a7082a4 commit 2755e92
Showing 3 changed files with 11 additions and 6 deletions.
2 changes: 2 additions & 0 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -488,6 +488,7 @@ def save_sharded_optimizer(
     is_master=control_saving,
     pinned_state_dict=pinned_state_dict,
     n_write_entries=self.N_WRITE_ENTRIES,
+    shard_preprocess=True,
 )
 self.pinned_state_dicts[id(optimizer)] = new_pinned_state_dict
 self.async_writers.extend(writers)
@@ -544,6 +545,7 @@ def save_sharded_optimizer(
     is_master=control_saving,
     pinned_state_dict=pinned_state_dict,
     n_write_entries=self.N_WRITE_ENTRIES,
+    shard_preprocess=True,
 )
 self.pinned_state_dicts[id(optimizer)] = new_pinned_state_dict
 self.async_writers.extend(writers)
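Both hunks make the same change: the optimizer saving path now passes shard_preprocess=True, so async_save_state_dict_shards flattens each optimizer shard into a plain name-to-tensor mapping before pinning and writing it. Model checkpoints keep the flag's default of False (see the utils.py change below) and are unaffected.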
13 changes: 8 additions & 5 deletions colossalai/checkpoint_io/utils.py
@@ -275,6 +275,7 @@ def async_save_state_dict_shards(
     pinned_state_dict: Optional[Dict[str, torch.Tensor]],
     n_write_entries: int,
     use_pp_format: bool = False,
+    shard_preprocess: bool = False,
 ) -> Tuple[int, Dict[str, torch.Tensor], list]:
     """
     Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
@@ -313,16 +314,18 @@ def async_save_state_dict_shards(

         writer = AsyncFileWriter(open(checkpoint_file_path, "wb"), n_write_entries, backend="pthread")
         writers.append(writer)
-
-        flatten_dicts, _ = _flatten_optim_state_dict(shard)
+        if shard_preprocess:
+            state_dict, _ = _flatten_optim_state_dict(shard)
+        else:
+            state_dict = shard
         if pinned_state_dict is not None:
-            sub_pinned_state_dict = {k: pinned_state_dict[k] for k in flatten_dicts.keys()}
+            sub_pinned_state_dict = {k: pinned_state_dict[k] for k in state_dict.keys()}
         else:
-            sub_pinned_state_dict = create_pinned_state_dict(flatten_dicts)
+            sub_pinned_state_dict = create_pinned_state_dict(state_dict)
         returned_state_dict.update(sub_pinned_state_dict)

         # Only save on master rank.
-        move_and_save(writer, state_dict=flatten_dicts, state_dict_pinned=sub_pinned_state_dict)
+        move_and_save(writer, state_dict=state_dict, state_dict_pinned=sub_pinned_state_dict)
         shard_filenames.append(shard_file)
         del shard

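For context, the shard_preprocess branch decides whether a shard is flattened into a flat name-to-tensor mapping before the pinned buffers are built and move_and_save is called. The sketch below is a minimal illustration of that idea; _flatten_optim_state_sketch and its key scheme are assumptions made for the example, not ColossalAI's actual _flatten_optim_state_dict.

from typing import Dict, Tuple

import torch


def _flatten_optim_state_sketch(optim_state: dict) -> Tuple[Dict[str, torch.Tensor], dict]:
    # Illustrative stand-in for _flatten_optim_state_dict: collapse the nested
    # {"state": {param_id: {name: tensor}}} layout into flat "state.<id>.<name>" keys
    # so the shard becomes a plain name -> tensor mapping. The real helper's key
    # scheme and metadata handling may differ.
    flat: Dict[str, torch.Tensor] = {}
    for param_id, per_param in optim_state.get("state", {}).items():
        for name, value in per_param.items():
            if isinstance(value, torch.Tensor):
                flat[f"state.{param_id}.{name}"] = value
    return flat, {"param_groups": optim_state.get("param_groups", [])}


def preprocess_shard(shard: dict, shard_preprocess: bool = False) -> Dict[str, torch.Tensor]:
    # Mirrors the new branch above: optimizer shards are flattened first, while model
    # shards are assumed to already be flat tensor dicts and pass through unchanged.
    if shard_preprocess:
        state_dict, _ = _flatten_optim_state_sketch(shard)
    else:
        state_dict = shard
    return state_dict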
2 changes: 1 addition & 1 deletion tests/test_checkpoint_io/test_safetensors_async_io.py
@@ -145,7 +145,7 @@ def test_save_load():
     model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
     model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
     f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
-    move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
+    move_and_save(f_writer, model_state_dict_cuda, state_dict_pinned=model_state_pinned)
     f_writer.sync_before_step()
     f_writer.synchronize()
     f_writer.fp.close()
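The test now names the pinned dict explicitly (state_dict_pinned=model_state_pinned), matching the keyword form used in async_save_state_dict_shards above, so the call no longer depends on the positional order of move_and_save's parameters, which this diff does not show.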
