[checkpointio] support asyncio for all models #6152

Open · wants to merge 18 commits into main
113 changes: 85 additions & 28 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -17,6 +17,8 @@
 from colossalai.accelerator import get_accelerator
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import (
+    async_save_state_dict_shards,
+    create_pinned_state_dict,
     get_model_base_filenames,
     get_optimizer_base_filenames,
     load_shard_state_dict,
@@ -28,6 +30,7 @@
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.utils.safetensors import load_flat
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
 
@@ -82,7 +85,15 @@ def save_unsharded_model(
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():
             if use_async:
-                super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
+                from colossalai.utils.safetensors import save
+
+                if id(model) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+                for k, v in state_dict.items():
+                    self.pinned_state_dicts[id(model)][k].copy_(v)
+                    state_dict[k] = self.pinned_state_dicts[id(model)][k]
+                writer = save(checkpoint, state_dict)
+                self.async_writers.append(writer)
             else:
                 save_state_dict(state_dict, checkpoint, use_safetensors)
 
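The hunk above caches one page-locked copy of the state dict per model, keyed by `id(model)`, so repeated checkpoints reuse the same pinned buffers and the device-to-host copy can proceed without blocking on pageable memory. Roughly, `create_pinned_state_dict` amounts to the following sketch (illustrative only, not the actual implementation in `colossalai.checkpoint_io.utils`):

```python
import torch

def pinned_like(state_dict: dict) -> dict:
    # One page-locked CPU buffer per tensor; allocating pinned memory is
    # expensive, so the caller caches the result and reuses it every save.
    return {
        k: torch.empty(v.shape, dtype=v.dtype, device="cpu", pin_memory=True)
        for k, v in state_dict.items()
    }

state = {"weight": torch.randn(4, 4)}
pinned = pinned_like(state)
for k, v in state.items():
    pinned[k].copy_(v)    # D2H copy lands in pinned memory
    state[k] = pinned[k]  # the async writer then reads stable CPU tensors
```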
@@ -106,7 +117,19 @@ def save_unsharded_optimizer(
         assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
         state_dict = optimizer.state_dict()
         if self.coordinator.is_master():
-            save_state_dict(state_dict, checkpoint, use_safetensors=False)
+            if use_async:
+                from colossalai.utils.safetensors import _flatten_optim_state_dict, save
+
+                flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
+                if id(optimizer) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
+                for k, v in flatten_state_dict.items():
+                    self.pinned_state_dicts[id(optimizer)][k].copy_(v)
+                    flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
+                writer = save(checkpoint, flatten_state_dict, metadata)
+                self.async_writers.append(writer)
+            else:
+                save_state_dict(state_dict, checkpoint, use_safetensors=False)
 
     def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
         """
@@ -137,17 +160,29 @@ def save_sharded_model(
 
         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
 
-        state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True)
+        if use_async and self.coordinator.is_master():
+            if id(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(model)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(model)]
+        else:
+            pinned_state_dicts = None
+        state_dict_shard = model.state_dict_shard(
+            max_shard_size=max_shard_size, only_rank_0=True, pinned_state_dicts=pinned_state_dicts
+        )
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         index_file = CheckpointIndexFile(checkpoint_path)
 
         # Save shards of optimizer states.
+        is_master = self.coordinator.is_master()
         if use_async:
-            super().save_sharded_model(
-                model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async
+            total_size, writers = async_save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint_path,
+                index_file=index_file,
+                base_filename=weights_name,
+                is_master=is_master,
             )
-
+            self.async_writers.extend(writers)
         else:
             total_size = save_state_dict_shards(
                 sharded_state_dict=state_dict_shard,
@@ -158,17 +193,17 @@ def save_sharded_model(
                 use_safetensors=use_safetensors,
             )
 
-            # only save the index file on the master rank
-            if self.coordinator.is_master():
-                index_file.append_meta_data("total_size", total_size)
-                index_file.write_index_file(save_index_file)
-                save_config_file(model.unwrap(), checkpoint_path)
-                self.logger.info(
-                    f"The model is split into checkpoint shards. "
-                    f"You can find where each parameters has been saved in the "
-                    f"index located at {save_index_file}.",
-                    ranks=[0],
-                )
+        # only save the index file on the master rank
+        if self.coordinator.is_master():
+            index_file.append_meta_data("total_size", total_size)
+            index_file.write_index_file(save_index_file)
+            save_config_file(model.unwrap(), checkpoint_path)
+            self.logger.info(
+                f"The model is split into checkpoint shards. "
+                f"You can find where each parameters has been saved in the "
+                f"index located at {save_index_file}.",
+                ranks=[0],
+            )
 
     def load_sharded_model(
         self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
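Two things happen in the pair of hunks above. First, `async_save_state_dict_shards` returns the total size plus a list of in-flight writer handles, which the plugin accumulates in `self.async_writers`. Second, since the async path no longer early-returns through `super()`, the index-file block is dedented out of the `else:` branch so the master rank writes the index in both code paths. A minimal stand-in for such a writer handle, using a background thread over stock safetensors (illustrative; the real writer lives in `colossalai.utils.safetensors`):

```python
import threading
from safetensors.torch import save_file

class AsyncWriterSketch:
    # Starts the write immediately and lets the training loop continue;
    # synchronize() must be called before the tensors are mutated again
    # or before the process exits.
    def __init__(self, path: str, state_dict: dict):
        self._thread = threading.Thread(target=save_file, args=(state_dict, path))
        self._thread.start()

    def synchronize(self) -> None:
        self._thread.join()  # block until the shard is fully on disk
```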
@@ -201,7 +236,7 @@ def save_sharded_optimizer(
         Path(checkpoint).mkdir(parents=True, exist_ok=True)
 
         # Preparing file paths and index file.
-        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
         index_file = CheckpointIndexFile(checkpoint)
         index_file.append_meta_data("param_groups", param_group_file)
 
@@ -212,17 +247,36 @@ def save_sharded_optimizer(
         torch.save(param_groups, group_file_path)
 
         # States are broken into shards within max_shard_size.
-        state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)
+        if use_async and self.coordinator.is_master():
+            if id(optimizer) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(optimizer)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
+        else:
+            pinned_state_dicts = None
+        state_dict_shard = optimizer.state_shard(
+            prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True, pinned_state_dicts=pinned_state_dicts
+        )
 
         # Save shards of optimizer states.
-        total_size = save_state_dict_shards(
-            sharded_state_dict=state_dict_shard,
-            checkpoint=checkpoint,
-            index_file=index_file,
-            base_filename=states_name,
-            is_master=self.coordinator.is_master(),
-            use_safetensors=False,
-        )
+        if use_async:
+            total_size, writers = async_save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint,
+                index_file=index_file,
+                base_filename=states_name,
+                is_master=self.coordinator.is_master(),
+                state_preprocess=True,
+            )
+            self.async_writers.extend(writers)
+        else:
+            total_size = save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint,
+                index_file=index_file,
+                base_filename=states_name,
+                is_master=self.coordinator.is_master(),
+                use_safetensors=False,
+            )
 
         # Wrap up index file. Only save it on master rank.
         if self.coordinator.is_master():
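Here `state_preprocess=True` presumably asks the shard saver to flatten each optimizer shard (as in the unsharded path above) before writing it as safetensors. Because every async save only queues work, pending writers have to be drained before shutdown or before the pinned buffers are reused; a sketch of that flush, assuming the collected writer handles expose a `synchronize()` method (the actual entry point on `CheckpointIO` may differ):

```python
def flush_async_writes(checkpoint_io) -> None:
    # Hypothetical flush of queued checkpoint writes.
    for writer in checkpoint_io.async_writers:
        writer.synchronize()  # wait for each background write to finish
    checkpoint_io.async_writers.clear()
```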
@@ -264,7 +318,10 @@ def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_fi
         # Load optimizer states from shard files under checkpoint path.
         # For each file, only load the states managed by current process.
         for shard_file in checkpoint_files:
-            state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            if shard_file.endswith(".safetensors"):
+                state_dict_shard = load_flat(shard_file)
+            else:
+                state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
             optimizer.load_param_states(state_dict_shard)
             del state_dict_shard
             gc.collect()
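Loading dispatches on the file extension so older `torch.save`-style shards keep working next to the new flat safetensors shards. The shape of that fallback, sketched with stock APIs (`load_flat` additionally restores the flattened optimizer layout, which this sketch omits):

```python
from pathlib import Path

import torch
from safetensors.torch import load_file

def load_shard_any_format(shard_file: str) -> dict:
    if shard_file.endswith(".safetensors"):
        # New async checkpoints: flat tensor map (un-flattening omitted here).
        return load_file(shard_file)
    # Legacy shards written by torch.save.
    return torch.load(Path(shard_file), map_location="cpu")
```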
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1488,7 +1488,7 @@ def seed_worker(worker_id):
         )
 
     def get_checkpoint_io(self) -> CheckpointIO:
-        return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+        return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage)
 
     def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert (
8 changes: 7 additions & 1 deletion colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -404,7 +404,13 @@ def __init__(
 
     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
-            self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
+            self.dp_group,
+            self.pp_group,
+            self.tp_group,
+            self.sp_group,
+            self.ep_group,
+            self.moe_dp_group,
+            self.zero_stage,
         )
 
     def configure(
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/torch_ddp_plugin.py
@@ -60,7 +60,7 @@ def save_unsharded_optimizer(
         """
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
        if self.coordinator.is_master():
-            super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
+            super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async)
 
     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """