[ckpt] Add async ckpt api for other plugin #6145

Closed · wants to merge 18 commits
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/gemini_plugin.py
@@ -82,7 +82,7 @@ def save_unsharded_model(
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async, state_dict)
else:
save_state_dict(state_dict, checkpoint, use_safetensors)
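For orientation, a minimal sketch of the call pattern this one-line change enables: the Gemini wrapper gathers the full state dict onto rank 0 once and hands it down, so the base class (next file) does not call model.state_dict() a second time. The helpers below are assumed names for illustration, not the plugin's real code.

import torch

# Illustrative sketch only: is_master and async_base_save stand in for the
# coordinator check and the base-class writer; both are assumed names.
def gemini_save_unsharded(model, checkpoint, use_async, is_master, async_base_save):
    state_dict = model.state_dict(only_rank_0=True)  # gathered once, on rank 0
    if is_master():
        if use_async:
            # New in this PR: hand the gathered dict to the base class so the
            # async path reuses it instead of calling model.state_dict() again.
            async_base_save(model, checkpoint, state_dict=state_dict)
        else:
            torch.save(state_dict, checkpoint)  # stand-in for save_state_dict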

15 changes: 10 additions & 5 deletions colossalai/checkpoint_io/general_checkpoint_io.py
@@ -45,11 +45,16 @@ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
model.load_state_dict(checkpoint, strict=strict)

def save_unsharded_model(
self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool,
use_safetensors: bool,
use_async: bool = False,
state_dict: dict = None,
):
state_dict = model.state_dict()

# TODO(FrankLeeeee): add support for gather_dtensor
if state_dict is None:
state_dict = model.state_dict()
if gather_dtensor:
pass

@@ -64,7 +69,7 @@ def save_unsharded_model(

else:
# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)
save_state_dict(model.state_dict(), checkpoint, use_safetensors)

def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
"""
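The signature change above makes state_dict an optional input: a subclass that has already gathered a (possibly rank-0-only) state dict can pass it in, while existing callers keep the old behavior. A minimal runnable sketch of that contract, with torch.save standing in for the real save_state_dict helper:

import torch
import torch.nn as nn

def save_unsharded(model: nn.Module, checkpoint: str, state_dict: dict = None):
    # A subclass (e.g. the Gemini plugin above) may pass a pre-gathered state
    # dict; plain callers fall back to model.state_dict(), as before.
    if state_dict is None:
        state_dict = model.state_dict()
    torch.save(state_dict, checkpoint)

model = nn.Linear(4, 2)
save_unsharded(model, "model.pt")                                  # old behavior
save_unsharded(model, "model.pt", state_dict=model.state_dict())   # new path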
70 changes: 41 additions & 29 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -9,6 +9,7 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from tensornvme.async_file_io import AsyncFileWriter
from torch.distributed import ProcessGroup
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
@@ -22,6 +23,7 @@
to_unpadded_tensor,
)
from colossalai.utils import get_current_device, get_non_persistent_buffers_set
from colossalai.utils.safetensors import move_and_save

from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
@@ -199,7 +201,6 @@ def save_sharded_model(
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
"""

assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
model = model.unwrap()
@@ -224,7 +225,18 @@
if self.pp_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
total_size, pinned_state_dict, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
pinned_state_dict=pinned_state_dict,
n_write_entries=self.N_WRITE_ENTRIES,
)
self.pinned_state_dicts[id(model)] = pinned_state_dict
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
@@ -234,16 +246,16 @@
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)

else:
# When pipeline is used, each stage produces its own shard files and index files.
@@ -259,24 +271,28 @@
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
if use_async:
total_size, returned_state_dict, writers = async_save_state_dict_shards(
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
total_size, pinned_state_dict, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
pinned_state_dict=pinned_state_dict,
n_write_entries=self.N_WRITE_ENTRIES,
)
self.pinned_state_dicts[id(model)] = pinned_state_dict
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
use_pp_format=True,
n_write_entries=191,
)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
use_pp_format=True,
)

if control_saving:
assert (
@@ -664,7 +680,6 @@ def save_unsharded_model(
model = model.unwrap()
if self.dp_rank != 0:
return

# The logic of collecting parameter shards along tp degree
# has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
state_dict = model.state_dict()
@@ -686,15 +701,12 @@
for _state_dict in state_dict_list:
complete_state_dict.update(_state_dict)
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import move_and_save

writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
self.async_writers.append(writer)
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
move_and_save(writer, complete_state_dict, self.pinned_state_dicts[id(model)])
else:
save_state_dict(complete_state_dict, checkpoint, use_safetensors)

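Two bookkeeping structures recur in the hunks above: self.pinned_state_dicts, a per-model cache (keyed by id(model)) of page-locked host buffers that device tensors are staged into, and self.async_writers, the list of outstanding AsyncFileWriter handles that a later _sync_io() can drain. A condensed sketch of that lifecycle; the call signatures are copied from this diff, while the import path for create_pinned_state_dict and the N_WRITE_ENTRIES value are assumptions:

from tensornvme.async_file_io import AsyncFileWriter

from colossalai.checkpoint_io.utils import create_pinned_state_dict  # path assumed
from colossalai.utils.safetensors import move_and_save

class AsyncCheckpointMixin:
    N_WRITE_ENTRIES = 32  # assumed; the real checkpoint IO defines its own value

    def __init__(self):
        self.pinned_state_dicts = {}  # id(model) -> pinned host copy
        self.async_writers = []       # writers still flushing in the background

    def save_unsharded_async(self, model, checkpoint: str):
        state_dict = model.state_dict()
        writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES,
                                 backend="pthread")
        if id(model) not in self.pinned_state_dicts:
            # Allocate the page-locked staging buffers once per model; repeated
            # saves of the same model reuse them.
            self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
        self.async_writers.append(writer)
        # D2H copy into the pinned buffers, then write to disk off the main thread.
        move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])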
30 changes: 21 additions & 9 deletions tests/test_checkpoint_io/test_gemini_checkpoint_io.py
@@ -35,7 +35,13 @@
@parameterize("use_safetensors", [False, True])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int):
def exam_state_dict_with_origin(
placement_config,
model_name,
use_safetensors: bool,
tp_size: int,
zero_size: int,
):
from transformers import BertForSequenceClassification

(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
@@ -71,19 +77,24 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
(model_size / 3),
use_safetensors=use_safetensors,
)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()
new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict())


@clear_cache_before_run()
@parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
@parameterize("shard", [True, False])
@parameterize("shard", [False, True])
@parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("size_per_shard", [32])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int):
@parameterize("use_async", [False, True])
def exam_state_dict(
placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int, use_async: bool
):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
enable_flash_attention = True if tp_size > 1 else False
@@ -124,14 +135,15 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
booster.save_model(
model,
model_ckpt_path,
shard=shard,
size_per_shard=size_per_shard,
)
if not use_async:
model_ckpt_path = f"{model_ckpt_path}.pt"
if use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)

booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()

booster.load_model(new_model, model_ckpt_path)
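The test changes encode two caller-facing conventions: the checkpoint filename gets a .pt suffix on the synchronous path and .safetensors on the async path, and pending device-to-host copies and background writes must be drained (_sync_d2h() / _sync_io()) before any rank reloads the file. A condensed sketch of that flow, assuming booster, model, and new_model are built as in the surrounding tests:

import torch.distributed as dist

def save_then_reload(booster, model, new_model, tempdir: str, use_async: bool):
    model_ckpt_path = f"{tempdir}/model"
    # Convention from these tests: .pt for sync saves, .safetensors for async.
    model_ckpt_path += ".safetensors" if use_async else ".pt"
    booster.save_model(model, model_ckpt_path, shard=False, use_async=use_async)
    booster.checkpoint_io._sync_d2h()  # wait for device-to-host copies
    booster.checkpoint_io._sync_io()   # wait for background file writes
    dist.barrier()                     # make the file visible to all ranks
    booster.load_model(new_model, model_ckpt_path)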
@@ -34,6 +34,7 @@
else:
TEST_CONFIGS = [
# TODO(ver217): other configs lead to hang
{"tp_size": 1, "pp_size": 1, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1},
{"tp_size": 1, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1},
]

@@ -42,8 +43,9 @@
@parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("size_per_shard", [32])
@parameterize("test_config", TEST_CONFIGS)
@parameterize("use_async", [True, False])
@clear_cache_before_run()
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool):
(model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
iter(model_zoo.get_sub_registry(model_name).values())
)
@@ -85,8 +87,14 @@ def _preprocess_data(data):
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
if not use_async:
model_ckpt_path = f"{model_ckpt_path}.pt"
if use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
dist.barrier()

new_model = model_fn().cuda()
9 changes: 7 additions & 2 deletions tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py
@@ -13,8 +13,9 @@


@parameterize("shard", [True, False])
@parameterize("use_async", [True, False])
@parameterize("size_per_shard", [16, 128])
def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
def check_torch_ddp_checkpointIO(shard: bool, use_async: bool, size_per_shard: int):
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = resnet18()
@@ -39,7 +40,11 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
if not use_async:
model_ckpt_path = f"{model_ckpt_path}.pt"
if use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
dist.barrier()
52 changes: 9 additions & 43 deletions tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
@@ -10,9 +10,8 @@

if version.parse(torch.__version__) >= version.parse("1.12.0"):
from colossalai.booster.plugin import TorchFSDPPlugin
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


def compare_nested_dict(dict1, dict2):
@@ -43,7 +42,9 @@ def compare_nested_dict(dict1, dict2):
return True


def check_torch_fsdp_ckpt():
@parameterize("shard", [True, False])
@parameterize("use_async", [True, False])
def check_torch_fsdp_ckpt(shard: bool, use_async: bool):
model = resnet18()
plugin = TorchFSDPPlugin()
booster = Booster(plugin=plugin)
@@ -64,10 +65,13 @@ def run_model():
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optim_ckpt_path = f"{tempdir}/optimizer"

if not use_async:
model_ckpt_path = f"{model_ckpt_path}.pt"
if use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
run_model()

booster.save_model(fsdp_model, model_ckpt_path, shard=False)
booster.save_model(fsdp_model, model_ckpt_path, shard=shard, use_async=use_async)
booster.save_optimizer(optimizer, optim_ckpt_path, shard=False)

full_msd = fsdp_model.state_dict()
@@ -100,44 +104,6 @@ def run_model():
outputs_sec = fsdp_model(inputs)
assert criterion(outputs_sec) == criterion(outputs)

with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optim_ckpt_path = f"{tempdir}/optimizer"

run_model()

booster.save_model(fsdp_model, model_ckpt_path, shard=True)
booster.save_optimizer(optimizer, optim_ckpt_path, shard=True)

full_msd = fsdp_model.unwrap().state_dict()
full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)

import copy

sharded_osd = copy.deepcopy(full_osd)

run_model()

full_msd_updated = fsdp_model.unwrap().state_dict()
full_osd_updated = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)

# cost much time led to timeout
# assert not compare_nested_dict(full_osd_updated, sharded_osd)
# assert not compare_nested_dict(full_msd_updated, full_msd)
outputs_first = fsdp_model(inputs)
assert criterion(outputs_first) != criterion(outputs)

booster.load_model(fsdp_model, model_ckpt_path)
booster.load_optimizer(optimizer, optim_ckpt_path)

full_msd_restore = fsdp_model.unwrap().state_dict()
sharded_osd_restore = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)

assert compare_nested_dict(sharded_osd, sharded_osd_restore)
assert compare_nested_dict(full_msd_restore, full_msd)
outputs_sec = fsdp_model(inputs)
assert criterion(outputs_sec) == criterion(outputs)


def run_dist(rank, world_size, port):
# init dist env
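The sharded-save block deleted above is not lost coverage: with @parameterize, one test body now runs for every (shard, use_async) combination. A small sketch of that usage, assuming colossalai.testing.parameterize re-invokes the wrapped function once per listed value when it is called without the parameterized arguments, as the call sites above suggest:

from colossalai.testing import parameterize

@parameterize("shard", [True, False])
@parameterize("use_async", [True, False])
def check_ckpt_matrix(shard: bool, use_async: bool):
    # Stand-in body; the real test saves and reloads an FSDP checkpoint.
    print(f"shard={shard}, use_async={use_async}")

check_ckpt_matrix()  # four invocations, one per (shard, use_async) pair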