[ckpt] Add async ckpt api for other plugin #6145

Closed
wants to merge 18 commits
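
For orientation, here is a minimal round-trip sketch of the async checkpoint API this PR extends to the remaining plugins. The `use_async` flag on `booster.save_model`, the `.safetensors` suffix for unsharded async saves, and the `_sync_d2h()` / `_sync_io()` flush calls all come from the test changes in this diff; the TorchDDP setup, model, and paths are illustrative assumptions, not part of the change.

```python
# Sketch only: assumes torch.distributed is already initialized
# (e.g. torchrun + colossalai.launch_from_torch) and a CUDA device is available.
import torch
import torch.distributed as dist
from torchvision.models import resnet18

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

model = resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# Unsharded async saves go through the safetensors writer, so the target path
# carries the .safetensors suffix (mirroring the updated tests below).
booster.save_model(model, "model.safetensors", shard=False, use_async=True)

# The write runs in the background: flush device-to-host copies and pending IO
# before any rank reads the checkpoint back.
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()
```
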
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/gemini_plugin.py
@@ -82,7 +82,7 @@ def save_unsharded_model(
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async, state_dict)
else:
save_state_dict(state_dict, checkpoint, use_safetensors)

17 changes: 11 additions & 6 deletions colossalai/checkpoint_io/general_checkpoint_io.py
@@ -45,11 +45,16 @@ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
model.load_state_dict(checkpoint, strict=strict)

def save_unsharded_model(
self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool,
use_safetensors: bool,
use_async: bool = False,
state_dict: dict = None,
):
state_dict = model.state_dict()

# TODO(FrankLeeeee): add support for gather_dtensor
if state_dict is None:
state_dict = model.state_dict()
if gather_dtensor:
pass

@@ -60,11 +65,11 @@ def save_unsharded_model(
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.async_writers.append(writer)
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
move_and_save(writer, state_dict, state_dict_pinned=self.pinned_state_dicts[id(model)])

else:
# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)
save_state_dict(model.state_dict(), checkpoint, use_safetensors)

def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
"""
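
The async unsharded path added above reduces to one pattern: cache a pinned host copy of the state dict per model, create an `AsyncFileWriter`, remember the writer so the IO can be flushed later, and let `move_and_save` do the device-to-host copy plus the queued safetensors write. A condensed sketch of that branch, assuming the attributes used in this diff (`N_WRITE_ENTRIES`, `pinned_state_dicts`, `async_writers`) and that `create_pinned_state_dict` lives in `colossalai.checkpoint_io.utils`:

```python
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.checkpoint_io.utils import create_pinned_state_dict
from colossalai.utils.safetensors import move_and_save


def async_save_unsharded(self, model, checkpoint: str) -> None:
    """Condensed sketch of the use_async branch of save_unsharded_model."""
    state_dict = model.state_dict()

    writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
    # Pinned buffers are cached per model so repeated saves reuse the same host memory.
    if id(model) not in self.pinned_state_dicts:
        self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
    self.async_writers.append(writer)  # flushed later by _sync_d2h() / _sync_io()

    # Copy device tensors into the pinned buffers and queue the safetensors write.
    move_and_save(writer, state_dict, state_dict_pinned=self.pinned_state_dicts[id(model)])
```

Note that when a caller (such as the Gemini plugin above) passes an already-gathered `state_dict`, the real method skips the `model.state_dict()` call and pins that dict instead.
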
72 changes: 43 additions & 29 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -9,6 +9,7 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from tensornvme.async_file_io import AsyncFileWriter
from torch.distributed import ProcessGroup
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
@@ -22,6 +23,7 @@
to_unpadded_tensor,
)
from colossalai.utils import get_current_device, get_non_persistent_buffers_set
from colossalai.utils.safetensors import save

from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
@@ -199,7 +201,6 @@ def save_sharded_model(
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
"""

assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
model = model.unwrap()
@@ -224,7 +225,18 @@ def save_sharded_model(
if self.pp_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
total_size, new_pinned_state_dict, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
pinned_state_dict=pinned_state_dict,
n_write_entries=self.N_WRITE_ENTRIES,
)
self.pinned_state_dicts[id(model)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
@@ -234,16 +246,16 @@ def save_sharded_model(
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)

else:
# When pipeline is used, each stage produces its own shard files and index files.
@@ -259,24 +271,30 @@ def save_sharded_model(
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
if use_async:
total_size, returned_state_dict, writers = async_save_state_dict_shards(
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
total_size, new_pinned_state_dict, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_pp_format=True,
n_write_entries=191,
pinned_state_dict=pinned_state_dict,
n_write_entries=self.N_WRITE_ENTRIES,
move=False,
)
self.pinned_state_dicts[id(model)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
use_pp_format=True,
)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
use_pp_format=True,
)

if control_saving:
assert (
@@ -664,7 +682,6 @@ def save_unsharded_model(
model = model.unwrap()
if self.dp_rank != 0:
return

# The logic of collecting parameter shards along tp degree
# has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
state_dict = model.state_dict()
@@ -686,15 +703,12 @@ def save_unsharded_model(
for _state_dict in state_dict_list:
complete_state_dict.update(_state_dict)
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import move_and_save

writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
self.async_writers.append(writer)
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
save(writer, complete_state_dict, self.pinned_state_dicts[id(model)])
else:
save_state_dict(complete_state_dict, checkpoint, use_safetensors)

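
Two things change in this file. The sharded async branches now call `async_save_state_dict_shards` directly with a per-model pinned-dict cache and `n_write_entries=self.N_WRITE_ENTRIES` (instead of a hard-coded 191 or a fallback to the unsharded path), and the unsharded async branch now pins and writes the tensor-parallel-gathered `complete_state_dict` via `save()` rather than the per-rank `state_dict`. A trimmed sketch of the shared caching pattern, with keyword names and the return triple taken from this diff; the standalone function framing is illustrative:

```python
from colossalai.checkpoint_io.utils import async_save_state_dict_shards


def async_save_shards(self, model, state_dict_shard, checkpoint, index_file, weights_name, control_saving):
    """Condensed sketch of the use_async sharded branch (pp_size == 1 shown;
    the pipeline branch additionally passes use_pp_format=True and move=False)."""
    pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
    total_size, new_pinned_state_dict, writers = async_save_state_dict_shards(
        sharded_state_dict=state_dict_shard,
        checkpoint=checkpoint,
        index_file=index_file,
        base_filename=weights_name,
        is_master=control_saving,
        pinned_state_dict=pinned_state_dict,
        n_write_entries=self.N_WRITE_ENTRIES,
    )
    # Keep the (possibly newly allocated) pinned buffers and the writers around so
    # _sync_d2h() / _sync_io() can block on them before the checkpoint is consumed.
    self.pinned_state_dicts[id(model)] = new_pinned_state_dict
    self.async_writers.extend(writers)
    return total_size
```
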
2 changes: 1 addition & 1 deletion colossalai/checkpoint_io/utils.py
@@ -321,7 +321,7 @@ def async_save_state_dict_shards(
returned_state_dict.update(sub_pinned_state_dict)

# Only save on master rank.
move_and_save(writer, shard, sub_pinned_state_dict)
move_and_save(writer, shard, state_dict_pinned=sub_pinned_state_dict)
shard_filenames.append(shard_file)
del shard

3 changes: 2 additions & 1 deletion colossalai/utils/safetensors.py
@@ -170,9 +170,10 @@ def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor])
def move_and_save(
f_writer: AsyncFileWriter,
state_dict: Dict[str, torch.Tensor],
metadata: Optional[Dict[str, str]] = None,
state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
) -> None:
prepared_data, _, tensor_keys = prepare(state_dict)
prepared_data, _, tensor_keys = prepare(state_dict, metadata)
n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset

f_writer.write(n.to_bytes(8, byteorder="little"))
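
This signature change explains why the call sites above switch to `state_dict_pinned=...` as a keyword argument: `move_and_save` now takes an optional `metadata` dict in the third positional slot (forwarded to `prepare`), so a positionally passed pinned dict would land in `metadata`. A small usage sketch of the updated signature; the metadata contents are an illustrative assumption, not something this diff writes:

```python
import torch
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import move_and_save

# Illustrative tensors: a CUDA state dict plus page-locked host mirrors for the D2H copy.
state_dict = {"weight": torch.randn(4, 4, device="cuda")}
pinned = {k: torch.empty_like(v, device="cpu").pin_memory() for k, v in state_dict.items()}

writer = AsyncFileWriter("model.safetensors", n_entries=191, backend="pthread")
move_and_save(
    writer,
    state_dict,
    metadata={"note": "illustrative only"},  # new optional arg, forwarded to prepare()
    state_dict_pinned=pinned,                # must now be passed by keyword
)
writer.sync_before_step()
writer.synchronize()
```
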
28 changes: 19 additions & 9 deletions tests/test_checkpoint_io/test_gemini_checkpoint_io.py
@@ -35,7 +35,13 @@
@parameterize("use_safetensors", [False, True])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int):
def exam_state_dict_with_origin(
placement_config,
model_name,
use_safetensors: bool,
tp_size: int,
zero_size: int,
):
from transformers import BertForSequenceClassification

(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
@@ -71,19 +77,24 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
(model_size / 3),
use_safetensors=use_safetensors,
)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()
new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict())


@clear_cache_before_run()
@parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
@parameterize("shard", [True, False])
@parameterize("shard", [False, True])
@parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("size_per_shard", [32])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int):
@parameterize("use_async", [False, True])
def exam_state_dict(
placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int, use_async: bool
):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
enable_flash_attention = True if tp_size > 1 else False
@@ -124,14 +135,13 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
booster.save_model(
model,
model_ckpt_path,
shard=shard,
size_per_shard=size_per_shard,
)
if not shard and use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)

booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()

booster.load_model(new_model, model_ckpt_path)
@@ -34,6 +34,7 @@
else:
TEST_CONFIGS = [
# TODO(ver217): other configs lead to hang
{"tp_size": 1, "pp_size": 1, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1},
{"tp_size": 1, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1},
]

@@ -42,8 +43,9 @@
@parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("size_per_shard", [32])
@parameterize("test_config", TEST_CONFIGS)
@parameterize("use_async", [True, False])
@clear_cache_before_run()
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool):
(model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
iter(model_zoo.get_sub_registry(model_name).values())
)
@@ -85,8 +87,12 @@ def _preprocess_data(data):
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
if not shard and use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()

new_model = model_fn().cuda()
@@ -47,8 +47,6 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us

model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
if not shard and not use_async:
model_ckpt_path = f"{model_ckpt_path}.pt"
if not shard and use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
if not shard and use_async:
6 changes: 5 additions & 1 deletion tests/test_checkpoint_io/test_safetensors_async_io.py
@@ -146,7 +146,11 @@ def test_save_load():
model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
move_and_save(
f_writer,
state_dict=model_state_dict_cuda,
state_dict_pinned=model_state_pinned,
)
f_writer.sync_before_step()
f_writer.synchronize()
del f_writer
9 changes: 7 additions & 2 deletions tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py
@@ -13,8 +13,9 @@


@parameterize("shard", [True, False])
@parameterize("use_async", [True, False])
@parameterize("size_per_shard", [16, 128])
def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
def check_torch_ddp_checkpointIO(shard: bool, use_async: bool, size_per_shard: int):
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = resnet18()
@@ -39,7 +40,11 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
if not use_async:
model_ckpt_path = f"{model_ckpt_path}.pt"
if use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
dist.barrier()