
Commit 25fcd7b

fix

flybird11111 committed Nov 20, 2024
1 parent 2755e92 commit 25fcd7b
Showing 3 changed files with 123 additions and 61 deletions.
77 changes: 57 additions & 20 deletions colossalai/checkpoint_io/general_checkpoint_io.py
@@ -8,13 +8,13 @@
import torch.nn as nn
from torch.optim import Optimizer

from colossalai.utils.safetensors import move_and_save
from colossalai.utils.safetensors import load_flat

from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
async_save_state_dict,
async_save_state_dict_shards,
create_pinned_state_dict,
get_model_base_filenames,
get_optimizer_base_filenames,
is_safetensors_available,
@@ -54,13 +54,16 @@ def save_unsharded_model(
pass

if use_async:
from tensornvme.async_file_io import AsyncFileWriter

writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.async_writers.append(writer)
move_and_save(writer, state_dict, state_dict_pinned=self.pinned_state_dicts[id(model)])
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
new_pinned_state_dict, writers = async_save_state_dict(
state_dict,
checkpoint,
pinned_state_dict,
self.N_WRITE_ENTRIES,
shard_preprocess=False,
)
self.pinned_state_dicts[id(model)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)
@@ -85,7 +88,10 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
if shard_file.endswith(".safetensors"):
state_dict = load_flat(shard_file)
else:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
load_states_into_optimizer(optimizer, state_dict, id_map)

sharded_optimizer_loading_epilogue(optimizer)
@@ -128,14 +134,29 @@ def save_sharded_optimizer(

# Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(
sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=True,
use_safetensors=False,
)
if use_async:
pinned_state_dict = self.pinned_state_dicts.get(id(optimizer), None)
total_size, new_pinned_state_dict, writers = async_save_state_dict_shards(
sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=True,
pinned_state_dict=pinned_state_dict,
n_write_entries=self.N_WRITE_ENTRIES,
shard_preprocess=True,
)
self.pinned_state_dicts[id(optimizer)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=True,
use_safetensors=False,
)

# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
@@ -147,7 +168,10 @@
)

def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
checkpoint = load_state_dict(checkpoint)
if checkpoint.endswith(".safetensors"):
checkpoint = load_flat(checkpoint)
else:
checkpoint = load_state_dict(checkpoint)
optimizer.load_state_dict(checkpoint)

def save_unsharded_optimizer(
@@ -158,7 +182,20 @@ def save_unsharded_optimizer(
use_async: bool = False,
):
# TODO(FrankLeeeee): handle distributed tensors
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
if use_async:
state_dict = optimizer.state_dict()
pinned_state_dict = self.pinned_state_dicts.get(id(optimizer), None)
new_pinned_state_dict, writers = async_save_state_dict(
state_dict,
checkpoint,
pinned_state_dict,
self.N_WRITE_ENTRIES,
shard_preprocess=True,
)
self.pinned_state_dicts[id(optimizer)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)

def save_sharded_model(
self,
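As an aside (not part of the diff), a minimal sketch of the load-side pattern the hunks above introduce: optimizer checkpoints written as flattened safetensors are read back with load_flat, while other files fall back to a regular torch load. The helper name and the torch.load fallback are assumptions for illustration only.

import torch
from colossalai.utils.safetensors import load_flat  # same import the hunks above add

def load_optimizer_state(checkpoint_path: str) -> dict:
    """Sketch only: mirrors the ".safetensors" branch added above."""
    if checkpoint_path.endswith(".safetensors"):
        # Flattened safetensors files are rebuilt into a nested optimizer state dict.
        return load_flat(checkpoint_path)
    # Anything else is treated as a regular torch checkpoint (illustrative fallback).
    return torch.load(checkpoint_path, map_location="cpu")
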
79 changes: 38 additions & 41 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -22,14 +22,14 @@
to_unpadded_tensor,
)
from colossalai.utils import get_current_device, get_non_persistent_buffers_set
from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat, move_and_save
from colossalai.utils.safetensors import load_flat

from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
StateDictSharder,
async_save_state_dict,
async_save_state_dict_shards,
create_pinned_state_dict,
gather_distributed_param,
get_model_base_filenames,
get_optimizer_base_filenames,
@@ -722,7 +722,16 @@ def save_unsharded_model(
# When pipeline is not used, let master rank directly save the collected state_dict.
if self.tp_rank == 0:
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
new_pinned_state_dict, writers = async_save_state_dict(
state_dict,
checkpoint,
pinned_state_dict,
self.N_WRITE_ENTRIES,
shard_preprocess=False,
)
self.pinned_state_dicts[id(model)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
save_state_dict(state_dict, checkpoint, use_safetensors)
else:
@@ -736,15 +745,16 @@ def save_unsharded_model(
for _state_dict in state_dict_list:
complete_state_dict.update(_state_dict)
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
self.async_writers.append(writer)
move_and_save(
writer, state_dict=complete_state_dict, state_dict_pinned=self.pinned_state_dicts[id(model)]
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
new_pinned_state_dict, writers = async_save_state_dict(
complete_state_dict,
checkpoint,
pinned_state_dict,
self.N_WRITE_ENTRIES,
shard_preprocess=False,
)
self.pinned_state_dicts[id(model)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
save_state_dict(complete_state_dict, checkpoint, use_safetensors)

@@ -830,22 +840,16 @@ def save_unsharded_optimizer(
state_dict = {"param_groups": param_groups, "state": local_states}
if self.coordinator.is_master():
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
if use_async and id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)

f_writer = AsyncFileWriter(
fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
)
move_and_save(
f_writer,
state_dict=flatten_state_dict,
metadata=metadata,
state_dict_pinned=self.pinned_state_dicts[id(optimizer)],
pinned_state_dict = self.pinned_state_dicts.get(id(optimizer), None)
new_pinned_state_dict, writers = async_save_state_dict(
state_dict,
checkpoint,
pinned_state_dict,
self.N_WRITE_ENTRIES,
shard_preprocess=True,
)
self.async_writers.append(f_writer)
self.pinned_state_dicts[id(optimizer)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
else:
@@ -864,22 +868,15 @@
for _states in states_list:
state_dict["state"].update(_states)
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)

f_writer = AsyncFileWriter(
fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
)
move_and_save(
f_writer,
state_dict=flatten_state_dict,
metadata=metadata,
state_dict_pinned=self.pinned_state_dicts[id(optimizer)],
pinned_state_dict = self.pinned_state_dicts.get(id(optimizer), None)
new_pinned_state_dict, writers = async_save_state_dict(
state_dict,
checkpoint,
pinned_state_dict,
self.N_WRITE_ENTRIES,
shard_preprocess=True,
)
self.async_writers.append(f_writer)
self.pinned_state_dicts[id(optimizer)] = new_pinned_state_dict
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)

28 changes: 28 additions & 0 deletions colossalai/checkpoint_io/utils.py
@@ -405,6 +405,34 @@ def save_state_dict(
torch.save(state_dict_cpu, checkpoint_file_path)


def async_save_state_dict(
state_dict: dict,
checkpoint_file_path: str,
pinned_state_dict: Optional[Dict[str, torch.Tensor]],
n_write_entries: int,
shard_preprocess: bool = False,
):
from tensornvme.async_file_io import AsyncFileWriter

async_writers = []
if shard_preprocess:
saved_state_dict, metadata = _flatten_optim_state_dict(state_dict)
else:
saved_state_dict, metadata = state_dict, None
if pinned_state_dict is None:
pinned_state_dict = create_pinned_state_dict(saved_state_dict)

f_writer = AsyncFileWriter(fp=open(checkpoint_file_path, "wb"), n_entries=n_write_entries, backend="pthread")
move_and_save(
f_writer,
state_dict=saved_state_dict,
metadata=metadata,
state_dict_pinned=pinned_state_dict,
)
async_writers.append(f_writer)
return pinned_state_dict, async_writers


def save_param_groups(state_dict: dict, group_file_path: str) -> None:
"""
Save information of param_groups to given file path.
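For context, a minimal sketch (not part of the diff) of how a caller is expected to drive the new async_save_state_dict helper, mirroring the call sites changed in the two files above. The import path and the wrapper function are assumptions; the attributes pinned_state_dicts, N_WRITE_ENTRIES, and async_writers are the ones those call sites already use.

from colossalai.checkpoint_io.utils import async_save_state_dict  # import path assumed from this commit

def save_optimizer_async(ckpt_io, optimizer, checkpoint: str):
    """Sketch only: ckpt_io stands in for a CheckpointIO-like object with
    pinned_state_dicts, N_WRITE_ENTRIES and async_writers, as used by the callers above."""
    state_dict = optimizer.state_dict()
    pinned_state_dict = ckpt_io.pinned_state_dicts.get(id(optimizer), None)
    # shard_preprocess=True flattens the nested optimizer state before the async write.
    new_pinned_state_dict, writers = async_save_state_dict(
        state_dict,
        checkpoint,
        pinned_state_dict,
        ckpt_io.N_WRITE_ENTRIES,
        shard_preprocess=True,
    )
    # Cache the pinned host buffers for reuse and keep the writers alive until they are synced.
    ckpt_io.pinned_state_dicts[id(optimizer)] = new_pinned_state_dict
    ckpt_io.async_writers.extend(writers)

Caching the returned pinned state dict keyed by id(optimizer) lets later saves reuse the page-locked host buffers instead of re-allocating them, and collecting every returned writer in async_writers appears to be what lets the pending I/O be flushed later.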
