support general param_id
flybird11111 committed Dec 17, 2024
1 parent 611dcb5 commit 8b1f649
Showing 8 changed files with 162 additions and 42 deletions.
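The full optimizer state returned by FSDP.full_optim_state_dict is keyed by parameter name, while the flattened safetensors checkpoints written on the async path use integer param ids. The changes below pack names into position-based ids when saving and rebuild the id-to-name mapping from the live optimizer state when loading. A minimal standalone sketch of that round trip follows, using toy data and illustrative helper names (pack_names_to_ids, unpack_ids_to_names); it mirrors the pack_group/get_index_mapping logic in the diff but is not the ColossalAI implementation.

from typing import Any, Dict

def pack_names_to_ids(full_osd: Dict[str, Any]) -> Dict[str, Any]:
    # Save path: replace parameter-name keys with position-based integer ids.
    name2id: Dict[str, int] = {}
    start_index = 0
    packed_groups = []
    for group in full_osd["param_groups"]:
        packed = {k: v for k, v in group.items() if k != "params"}
        name2id.update({p: i for i, p in enumerate(group["params"], start_index) if p not in name2id})
        packed["params"] = [name2id[p] for p in group["params"]]
        start_index += len(packed["params"])
        packed_groups.append(packed)
    return {"state": {name2id[k]: v for k, v in full_osd["state"].items()},
            "param_groups": packed_groups}

def unpack_ids_to_names(saved_osd: Dict[str, Any], live_osd: Dict[str, Any]) -> Dict[str, Any]:
    # Load path: rebuild id -> name from the live full optimizer state, then translate back.
    id2name: Dict[int, str] = {}
    start_index = 0
    for group in live_osd["param_groups"]:
        id2name.update({i: p for i, p in enumerate(group["params"], start_index) if i not in id2name})
        start_index = len(id2name)
    return {"state": {id2name[int(k)]: v for k, v in saved_osd["state"].items()},
            "param_groups": [{**g, "params": [id2name[i] for i in g["params"]]}
                             for g in saved_osd["param_groups"]]}

# Toy round trip: two named parameters in one group.
live = {"state": {"layer.weight": {"step": 1}, "layer.bias": {"step": 1}},
        "param_groups": [{"lr": 1e-3, "params": ["layer.weight", "layer.bias"]}]}
packed = pack_names_to_ids(live)                  # state keyed by 0 and 1
restored = unpack_ids_to_names(packed, live)
assert restored["param_groups"][0]["params"] == ["layer.weight", "layer.bias"]
assert restored["state"]["layer.weight"] == {"step": 1}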
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1488,7 +1488,7 @@ def seed_worker(worker_id):
)

def get_checkpoint_io(self) -> CheckpointIO:
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage)

def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert (
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -404,7 +404,7 @@ def __init__(

def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(
self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.ep_group, self.moe_dp_group, self.zero_stage
)

def configure(
176 changes: 145 additions & 31 deletions colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -1,6 +1,5 @@
import os
from pathlib import Path
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Any

import torch
import torch.nn as nn
@@ -52,10 +51,40 @@ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: boo
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
if checkpoint.endswith(".safetensors"):
checkpoint = load_flat(checkpoint, seperator="-")
checkpoint = load_flat(checkpoint, seperator=".")
else:
checkpoint = utils.load_state_dict(checkpoint)

fsdp_model = optimizer.unwrap_model()
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=False)
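# Rebuild the id -> parameter-name mapping in the same order the ids were assigned at save
# time, so the integer keys stored in the checkpoint can be translated back to names below.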
start_index = 0
id2name = {}
def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
start_num = len(id2name)
id2name.update(
{
i: p
for i, p in enumerate(group["params"], start_index)
if i not in id2name
}
)
end_num = len(id2name)
start_index += end_num - start_num

for g in full_optimizer_state["param_groups"]:
get_index_mapping(g)

new_state = {}
for key, value in checkpoint["state"].items():
new_state[id2name[int(key)]] = value
checkpoint["state"] = new_state
for g in checkpoint["param_groups"]:
new_group = []
for param_id in g["params"]:
new_group.append(id2name[param_id])
g["params"] = new_group

sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
optimizer.load_state_dict(sharded_osd)

@@ -70,18 +99,19 @@ def save_unsharded_model(
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
full_model_state = model.state_dict()
if use_async:
from colossalai.utils.safetensors import save

if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state)
for k, v in full_model_state.items():
self.pinned_state_dicts[id(model)][k].copy_(v)
full_model_state[k] = self.pinned_state_dicts[id(model)][k]
writer = save(checkpoint, full_model_state)
self.async_writers.append(writer)
else:
utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
if self.coordinator.is_master():
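# Only the master rank writes; with FullStateDictConfig(rank0_only=True) above, only
# rank 0 holds the gathered full state dict anyway.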
if use_async:
from colossalai.utils.safetensors import save

if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state)
for k, v in full_model_state.items():
self.pinned_state_dicts[id(model)][k].copy_(v)
full_model_state[k] = self.pinned_state_dicts[id(model)][k]
writer = save(checkpoint, full_model_state)
self.async_writers.append(writer)
else:
utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)

def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
@@ -91,20 +121,48 @@ def save_unsharded_optimizer(
"""
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
fsdp_model = optimizer.unwrap_model()

full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
if use_async:
from colossalai.utils.safetensors import _flatten_optim_state_dict, save

flatten_state_dict, metadata = _flatten_optim_state_dict(full_optimizer_state, seperator="-")
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
for k, v in flatten_state_dict.items():
self.pinned_state_dicts[id(optimizer)][k].copy_(v)
flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
writer = save(checkpoint, state_dict=flatten_state_dict, metadata=metadata)
self.async_writers.append(writer)
else:
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)

if self.coordinator.is_master():

# Save order indices instead of Tensors
name2id: Dict[str, int] = {}
start_index = 0

def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
packed = {k: v for k, v in group.items() if k != "params"}
name2id.update(
{
p: i
for i, p in enumerate(group["params"], start_index)
if p not in name2id
}
)
packed["params"] = [name2id[p] for p in group["params"]]
start_index += len(packed["params"])
return packed

param_groups = [pack_group(g) for g in full_optimizer_state["param_groups"]]
full_optimizer_state["param_groups"] = param_groups
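# Re-key the optimizer state itself from parameter names to the same integer ids.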
new_state = {}
for key, value in full_optimizer_state["state"].items():
new_state[name2id[key]] = value
full_optimizer_state["state"] = new_state

if use_async:
from colossalai.utils.safetensors import _flatten_optim_state_dict, save
flatten_state_dict, metadata = _flatten_optim_state_dict(full_optimizer_state, seperator=".")
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
for k, v in flatten_state_dict.items():
self.pinned_state_dicts[id(optimizer)][k].copy_(v)
flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
writer = save(checkpoint, state_dict=flatten_state_dict, metadata=metadata)
self.async_writers.append(writer)
else:
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)

def save_sharded_model(
self,
@@ -150,7 +208,7 @@ def save_sharded_model(
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=True,
is_master=self.coordinator.is_master(),
)
self.async_writers.extend(writers)
else:
@@ -234,6 +292,32 @@ def save_sharded_optimizer(
)

if self.coordinator.is_master():

# Save order indices instead of Tensors
name2id: Dict[str, int] = {}
start_index = 0

def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
packed = {k: v for k, v in group.items() if k != "params"}
name2id.update(
{
p: i
for i, p in enumerate(group["params"], start_index)
if p not in name2id
}
)
packed["params"] = [name2id[p] for p in group["params"]]
start_index += len(packed["params"])
return packed

param_groups = [pack_group(g) for g in fsdp_optim_state["param_groups"]]
fsdp_optim_state["param_groups"] = param_groups
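# Re-key the sharded optimizer state from parameter names to the integer ids recorded in param_groups.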
new_state = {}
for key, value in fsdp_optim_state["state"].items():
new_state[name2id[key]] = value
fsdp_optim_state["state"] = new_state

# Preparing file paths and index file.
states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(
prefix, use_safetensors=use_async
@@ -261,7 +345,7 @@ def save_sharded_optimizer(
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=True,
is_master=self.coordinator.is_master(),
state_preprocess=True,
)
self.async_writers.extend(writers)
@@ -306,13 +390,43 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, siz
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
for shard_file in checkpoint_files:
if shard_file.endswith(".safetensors"):
state_dict_shard = load_flat(shard_file, seperator="-")
state_dict_shard = load_flat(shard_file, seperator=".")
else:
state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
fsdp_optim_state.update(state_dict_shard)

fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)

fsdp_model = optimizer.unwrap_model()
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model.unwrap(), optim=optimizer, rank0_only=False)
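# As in load_unsharded_optimizer: rebuild id -> parameter name from the live full optimizer
# state so the integer ids in the loaded shards can be mapped back to names.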
start_index = 0
id2name = {}
def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
start_num = len(id2name)
id2name.update(
{
i: p
for i, p in enumerate(group["params"], start_index)
if i not in id2name
}
)
end_num = len(id2name)
start_index += end_num - start_num

for g in full_optimizer_state["param_groups"]:
get_index_mapping(g)

new_state = {}
for key, value in fsdp_optim_dict["state"].items():
new_state[id2name[int(key)]] = value
fsdp_optim_dict["state"] = new_state
for g in fsdp_optim_dict["param_groups"]:
new_group = []
for param_id in g["params"]:
new_group.append(id2name[param_id])
g["params"] = new_group

with FSDP.state_dict_type(optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT):
fsdp_state = FSDP.optim_state_dict_to_load(
model=optimizer.unwrap_model().unwrap(), optim=optimizer, optim_state_dict=fsdp_optim_dict
11 changes: 8 additions & 3 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -70,16 +70,19 @@ def __init__(
dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
sp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True,
) -> None:
super().__init__()
self.global_dp_group = dp_group
self.pp_group = pp_group
self.tp_group = tp_group
self.sp_group = sp_group
self.dp_rank = dist.get_rank(self.global_dp_group)
self.tp_rank = dist.get_rank(self.tp_group)
self.pp_rank = dist.get_rank(self.pp_group)
self.sp_rank = dist.get_rank(self.sp_group)
self.global_dp_size = dist.get_world_size(dp_group)
self.pp_size = dist.get_world_size(pp_group)
self.tp_size = dist.get_world_size(tp_group)
@@ -490,7 +493,7 @@ def save_sharded_optimizer(

# Then collect the sharded states along dp_group(if using zero)/tp_group.
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
control_saving = self.dp_rank == 0 and self.tp_rank == 0
control_saving = self.dp_rank == 0 and self.tp_rank == 0 and self.sp_rank == 0

if use_async and control_saving:
if id(optimizer) not in self.pinned_state_dicts:
@@ -560,8 +563,10 @@ def save_sharded_optimizer(
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)

# Manage filenames of sharded weights and index file for each pipeline stage.
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
states_name = states_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
if not use_async:
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
else:
states_name = states_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)

3 changes: 2 additions & 1 deletion colossalai/checkpoint_io/moe_checkpoint.py
@@ -44,12 +44,13 @@ def __init__(
global_dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
sp_group: ProcessGroup,
ep_group: ProcessGroup,
moe_dp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True,
) -> None:
super().__init__(global_dp_group, pp_group, tp_group, zero_stage, verbose)
super().__init__(global_dp_group, pp_group, tp_group, sp_group, zero_stage, verbose)
self.global_dp_group = global_dp_group
self.global_dp_rank = dist.get_rank(global_dp_group)
self.global_dp_size = dist.get_world_size(global_dp_group)
2 changes: 1 addition & 1 deletion colossalai/checkpoint_io/utils.py
@@ -306,7 +306,7 @@ def async_save_state_dict_shards(
checkpoint_file_path = os.path.join(checkpoint, shard_file)

if state_preprocess:
state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator="-")
state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator=".")
else:
state_dict = shard

4 changes: 2 additions & 2 deletions colossalai/utils/safetensors.py
@@ -59,7 +59,7 @@ def _cast_to_object(tensor: torch.Tensor):
return _tensor_to_object(tensor, tensor.numel() * tensor.element_size())


def _flatten_optim_state_dict(state_dict: dict, seperator: str = "-") -> Tuple[dict, Optional[dict]]:
def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[dict, Optional[dict]]:
flat_dict = {}
non_tensor_keys = []
if "state" in state_dict:
@@ -196,7 +196,7 @@ def move_and_save(
return f_writer


def load_flat(checkpoint_path, seperator: str = "-"):
def load_flat(checkpoint_path, seperator: str = "."):
with safe_open(checkpoint_path, framework="pt") as f:
metadata = f.metadata()
state_dict_load = load_file(checkpoint_path)
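For context on the separator change above (default "-" to "."), here is a toy illustration of flattening nested optimizer state with a "." separator. The key layout is an assumption for illustration only, not the actual colossalai.utils.safetensors implementation.

def flatten(optim_state: dict, sep: str = ".") -> dict:
    # Join "<param_id><sep><state_name>" into one flat key per tensor.
    flat = {}
    for param_id, sub in optim_state["state"].items():
        for name, tensor in sub.items():
            flat[f"{param_id}{sep}{name}"] = tensor
    return flat

def unflatten(flat: dict, sep: str = ".") -> dict:
    # Split on the first separator only; with integer ids, keys like "0.exp_avg" split
    # cleanly, whereas dotted parameter names such as "layer.0.weight" would not.
    state = {}
    for key, tensor in flat.items():
        param_id, name = key.split(sep, 1)
        state.setdefault(int(param_id), {})[name] = tensor
    return {"state": state}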
4 changes: 2 additions & 2 deletions tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
@@ -114,8 +114,8 @@ def run_model():

run_model()

booster.save_model(fsdp_model, model_ckpt_path, shard=True, use_async=False)
booster.save_optimizer(optimizer, optim_ckpt_path, shard=True, use_async=True)
booster.save_model(fsdp_model, model_ckpt_path, shard=True, use_async=use_async)
booster.save_optimizer(optimizer, optim_ckpt_path, shard=True, use_async=use_async)
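# Both save calls now follow the test's use_async parameter instead of hard-coded flags.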

booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
