Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
flybird11111 committed Nov 20, 2024
1 parent e324c8a commit 5c49004
Showing 1 changed file with 103 additions and 133 deletions.
236 changes: 103 additions & 133 deletions tests/test_checkpoint_io/test_gemini_torch_compability.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import os

import pytest
import torch
import torch.distributed as dist
from transformers import LlamaForCausalLM
from torch.optim import Adam
from utils import shared_tempdir

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.booster.plugin import GeminiPlugin, TorchDDPPlugin
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import (
check_state_dict_equal,
Expand All @@ -20,105 +17,85 @@
)
from tests.kit.model_zoo import model_zoo

MODEL_PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.5},
]

OPTIM_PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.5}, # zero2-offload-half
]


@clear_cache_before_run()
@parameterize("placement_config", MODEL_PLACEMENT_CONFIGS)
@parameterize("model_name", ["transformers_bert_for_sequence_classification"])
@parameterize("use_safetensors", [False, True])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
def exam_state_dict_with_origin(
placement_config,
model_name,
use_safetensors: bool,
tp_size: int,
zero_size: int,
):
from transformers import BertForSequenceClassification

@parameterize("shard", [False, True])
@parameterize("model_name", ["transformers_llama_for_causal_lm"])
def exam_torch_load_from_gemini(shard: bool, model_name: str):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
bert_model = model_fn()
criterion = lambda x: x.mean()
plugin = GeminiPlugin(precision="fp16", initial_scale=(2**14))
booster = Booster(plugin=plugin)

model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=0.001)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

data = data_gen_fn()
data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()}
output = model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])

enable_flash_attention = True if tp_size > 1 else False
enable_fused_normalization = True if tp_size > 1 else False
enable_jit_fused = True if tp_size > 1 else False
booster.backward(loss, optimizer)
optimizer.step()

with shared_tempdir() as tempdir:
pretrained_path = os.path.join(tempdir, "pretrained")
bert_model.config.save_pretrained(save_directory=pretrained_path)

extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
plugin = GeminiPlugin(
**placement_config,
tp_size=tp_size,
enable_flash_attention=enable_flash_attention,
enable_fused_normalization=enable_fused_normalization,
enable_jit_fused=enable_jit_fused,
extra_dp_size=extra_dp_size,
)
booster = Booster(plugin=plugin)
bert_model, _, _, _, _ = booster.boost(bert_model)
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2

booster.save_model(
bert_model,
pretrained_path,
True,
True,
"",
(model_size / 3),
use_safetensors=use_safetensors,
)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"

booster.save_model(model, model_ckpt_path, shard=shard)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
dist.barrier()
new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict())

new_model = model_fn()
new_optimizer = Adam(new_model.parameters(), lr=0.001)
new_plugin = TorchDDPPlugin()
new_booster = Booster(plugin=new_plugin)
new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)

# Loading HybridAdam states to torch.Adam
new_booster.load_model(new_model, model_ckpt_path, strict=True)

# Add prefix to get aligned with pytorch parameter names.
check_state_dict_equal(
model.state_dict(only_rank_0=False, prefix="module.module."),
new_model.state_dict(),
ignore_device=False,
ignore_dtype=True,
)

new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(), ignore_device=False)

# Check the new model/optimizer can successfully run.
data = data_gen_fn()
data = {
k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
}
output = new_model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
new_booster.backward(loss, new_optimizer)
new_optimizer.step()
new_booster.save_model(new_model, model_ckpt_path, shard=shard)
new_booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)


@clear_cache_before_run()
@parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
@parameterize("shard", [False, True])
@parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("size_per_shard", [32])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
@parameterize("use_async", [False, True])
def exam_state_dict(
placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int, use_async: bool
):
@parameterize("model_name", ["transformers_gpt"])
def exam_gemini_load_from_torch(shard: bool, model_name: str):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
enable_flash_attention = True if tp_size > 1 else False
enable_fused_normalization = True if tp_size > 1 else False
enable_jit_fused = True if tp_size > 1 else False
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
plugin = GeminiPlugin(
**placement_config,
precision="fp16",
initial_scale=(2**14),
tp_size=tp_size,
extra_dp_size=extra_dp_size,
enable_flash_attention=enable_flash_attention,
enable_fused_normalization=enable_fused_normalization,
enable_jit_fused=enable_jit_fused,
)
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)

model = model_fn()
new_model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=0.001)
optimizer = Adam(model.parameters(), lr=0.001)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
new_optimizer = HybridAdam(new_model.parameters(), lr=0.01)
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

data = data_gen_fn()
data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()}
Expand All @@ -129,32 +106,46 @@ def exam_state_dict(

booster.backward(loss, optimizer)
optimizer.step()
for group in optimizer.param_groups:
group["lr"] = 0.1

with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
if not use_async:
model_ckpt_path = f"{model_ckpt_path}.pt"
if use_async:
model_ckpt_path = f"{model_ckpt_path}.safetensors"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)

booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
booster.checkpoint_io._sync_d2h()
booster.checkpoint_io._sync_io()

booster.save_model(model, model_ckpt_path, shard=shard)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
dist.barrier()

booster.load_model(new_model, model_ckpt_path)
new_model = model_fn()
new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
new_plugin = GeminiPlugin()
new_booster = Booster(plugin=new_plugin)
new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)

# Loading torch.Adam states to HybridAdam
new_booster.load_model(new_model, model_ckpt_path, strict=True)

# Add prefix to get aligned with pytorch parameter names.
check_state_dict_equal(
model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), ignore_dtype=True
new_model.state_dict(only_rank_0=False, prefix="module.module."),
model.state_dict(),
ignore_device=False,
ignore_dtype=True,
)

booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False))
for group in new_optimizer.param_groups:
assert group["lr"] == 0.1
new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
old_state_dict = optimizer.state_dict()
new_state_dict = new_optimizer.state_dict(only_rank_0=False)

# Comparison of param_groups needs special care here,
# since not all hyperparameters in Adam are used by HybridAdam
hyperparameters_to_examine = ["params", "lr", "betas", "eps", "weight_decay"]
for old_group, new_group in zip(old_state_dict["param_groups"], new_state_dict["param_groups"]):
for k in hyperparameters_to_examine:
assert (
k in old_group and k in new_group
), f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
assert old_group[k] == new_group[k]
check_state_dict_equal(old_state_dict["state"], new_state_dict["state"], ignore_device=False)

# Check the new model/optimizer can successfully run.
data = data_gen_fn()
Expand All @@ -165,41 +156,20 @@ def exam_state_dict(
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
booster.backward(loss, new_optimizer)
new_booster.backward(loss, new_optimizer)
new_optimizer.step()
booster.save_model(new_model, model_ckpt_path, shard=shard)
booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)


def exam_lazy_from_pretrained():
llama_path = os.environ["LLAMA_PATH"]
plugin = GeminiPlugin()
booster = Booster(plugin=plugin)
orig_model = LlamaForCausalLM.from_pretrained(llama_path)
orig_state_dict = {k: v.half() for k, v in orig_model.state_dict().items()}
with LazyInitContext():
model = LlamaForCausalLM.from_pretrained(llama_path)
model, *_ = booster.boost(model)
with shared_tempdir() as tempdir:
save_path = os.path.join(tempdir, "model.pt")
booster.save_model(model, save_path, shard=False)
dist.barrier()
state_dict = torch.load(save_path, map_location="cpu")
check_state_dict_equal(state_dict, orig_state_dict, ignore_dtype=True)
new_booster.save_model(new_model, model_ckpt_path, shard=shard)
new_booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)


def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_state_dict()
# exam_state_dict_with_origin()
# exam_lazy_from_pretrained()
exam_torch_load_from_gemini()
exam_gemini_load_from_torch()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO():
spawn(run_dist, 4)


if __name__ == "__main__":
test_gemini_ckpIO()
def test_gemini_ckpIO(world_size):
spawn(run_dist, world_size)

0 comments on commit 5c49004

Please sign in to comment.