Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] authored and flybird11111 committed Dec 17, 2024
1 parent d0e2baa commit 4e71953
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 37 deletions.
8 changes: 7 additions & 1 deletion colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,13 @@ def __init__(

def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(
self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.ep_group, self.moe_dp_group, self.zero_stage
self.dp_group,
self.pp_group,
self.tp_group,
self.sp_group,
self.ep_group,
self.moe_dp_group,
self.zero_stage,
)

def configure(
Expand Down
53 changes: 17 additions & 36 deletions colossalai/booster/plugin/torch_fsdp_plugin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Any
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple

import torch
import torch.nn as nn
Expand Down Expand Up @@ -59,22 +59,17 @@ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=False)
start_index = 0
id2name = {}

def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
start_num = len(id2name)
id2name.update(
{
i: p
for i, p in enumerate(group["params"], start_index)
if i not in id2name
}
)
id2name.update({i: p for i, p in enumerate(group["params"], start_index) if i not in id2name})
end_num = len(id2name)
start_index += end_num - start_num

for g in full_optimizer_state["param_groups"]:
get_index_mapping(g)

new_state = {}
for key, value in checkpoint["state"].items():
new_state[id2name[int(key)]] = value
Expand Down Expand Up @@ -111,7 +106,9 @@ def save_unsharded_model(
writer = save(checkpoint, full_model_state)
self.async_writers.append(writer)
else:
utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
utils.save_state_dict(
full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors
)

def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
Expand All @@ -133,17 +130,11 @@ def save_unsharded_optimizer(
def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
packed = {k: v for k, v in group.items() if k != "params"}
name2id.update(
{
p: i
for i, p in enumerate(group["params"], start_index)
if p not in name2id
}
)
name2id.update({p: i for i, p in enumerate(group["params"], start_index) if p not in name2id})
packed["params"] = [name2id[p] for p in group["params"]]
start_index += len(packed["params"])
return packed

param_groups = [pack_group(g) for g in full_optimizer_state["param_groups"]]
full_optimizer_state["param_groups"] = param_groups
new_state = {}
Expand All @@ -153,6 +144,7 @@ def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:

if use_async:
from colossalai.utils.safetensors import _flatten_optim_state_dict, save

flatten_state_dict, metadata = _flatten_optim_state_dict(full_optimizer_state, seperator=".")
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
Expand Down Expand Up @@ -300,17 +292,11 @@ def save_sharded_optimizer(
def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
packed = {k: v for k, v in group.items() if k != "params"}
name2id.update(
{
p: i
for i, p in enumerate(group["params"], start_index)
if p not in name2id
}
)
name2id.update({p: i for i, p in enumerate(group["params"], start_index) if p not in name2id})
packed["params"] = [name2id[p] for p in group["params"]]
start_index += len(packed["params"])
return packed

param_groups = [pack_group(g) for g in fsdp_optim_state["param_groups"]]
fsdp_optim_state["param_groups"] = param_groups
new_state = {}
Expand Down Expand Up @@ -401,22 +387,17 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, siz
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model.unwrap(), optim=optimizer, rank0_only=False)
start_index = 0
id2name = {}

def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
start_num = len(id2name)
id2name.update(
{
i: p
for i, p in enumerate(group["params"], start_index)
if i not in id2name
}
)
id2name.update({i: p for i, p in enumerate(group["params"], start_index) if i not in id2name})
end_num = len(id2name)
start_index += end_num - start_num

for g in full_optimizer_state["param_groups"]:
get_index_mapping(g)

new_state = {}
for key, value in fsdp_optim_dict["state"].items():
new_state[id2name[int(key)]] = value
Expand Down

0 comments on commit 4e71953

Please sign in to comment.