diff --git a/tests/test_checkpoint_io/test_gemini_torch_compability.py b/tests/test_checkpoint_io/test_gemini_torch_compability.py
index 02a4e0b7051c..ce4d10322ba5 100644
--- a/tests/test_checkpoint_io/test_gemini_torch_compability.py
+++ b/tests/test_checkpoint_io/test_gemini_torch_compability.py
@@ -1,15 +1,12 @@
-import os
-
 import pytest
 import torch
 import torch.distributed as dist
-from transformers import LlamaForCausalLM
+from torch.optim import Adam
 from utils import shared_tempdir
 
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin
-from colossalai.lazy import LazyInitContext
+from colossalai.booster.plugin import GeminiPlugin, TorchDDPPlugin
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import (
     check_state_dict_equal,
@@ -20,105 +17,85 @@
 )
 from tests.kit.model_zoo import model_zoo
 
-MODEL_PLACEMENT_CONFIGS = [
-    {"placement_policy": "static", "shard_param_frac": 0.5},
-]
-
-OPTIM_PLACEMENT_CONFIGS = [
-    {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.5},  # zero2-offload-half
-]
-
 
 @clear_cache_before_run()
-@parameterize("placement_config", MODEL_PLACEMENT_CONFIGS)
-@parameterize("model_name", ["transformers_bert_for_sequence_classification"])
-@parameterize("use_safetensors", [False, True])
-@parameterize("tp_size", [1, 2])
-@parameterize("zero_size", [2])
-def exam_state_dict_with_origin(
-    placement_config,
-    model_name,
-    use_safetensors: bool,
-    tp_size: int,
-    zero_size: int,
-):
-    from transformers import BertForSequenceClassification
-
+@parameterize("shard", [False, True])
+@parameterize("model_name", ["transformers_llama_for_causal_lm"])
+def exam_torch_load_from_gemini(shard: bool, model_name: str):
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
-    bert_model = model_fn()
+    criterion = lambda x: x.mean()
+    plugin = GeminiPlugin(precision="fp16", initial_scale=(2**14))
+    booster = Booster(plugin=plugin)
+
+    model = model_fn()
+    optimizer = HybridAdam(model.parameters(), lr=0.001)
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+    data = data_gen_fn()
+    data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()}
+    output = model(**data)
+    output = output_transform_fn(output)
+    output_key = list(output.keys())[0]
+    loss = criterion(output[output_key])
 
-    enable_flash_attention = True if tp_size > 1 else False
-    enable_fused_normalization = True if tp_size > 1 else False
-    enable_jit_fused = True if tp_size > 1 else False
+    booster.backward(loss, optimizer)
+    optimizer.step()
 
     with shared_tempdir() as tempdir:
-        pretrained_path = os.path.join(tempdir, "pretrained")
-        bert_model.config.save_pretrained(save_directory=pretrained_path)
-
-        extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
-        plugin = GeminiPlugin(
-            **placement_config,
-            tp_size=tp_size,
-            enable_flash_attention=enable_flash_attention,
-            enable_fused_normalization=enable_fused_normalization,
-            enable_jit_fused=enable_jit_fused,
-            extra_dp_size=extra_dp_size,
-        )
-        booster = Booster(plugin=plugin)
-        bert_model, _, _, _, _ = booster.boost(bert_model)
-        model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
-
-        booster.save_model(
-            bert_model,
-            pretrained_path,
-            True,
-            True,
-            "",
-            (model_size / 3),
-            use_safetensors=use_safetensors,
-        )
-        booster.checkpoint_io._sync_d2h()
-        booster.checkpoint_io._sync_io()
+        model_ckpt_path = f"{tempdir}/model"
+        optimizer_ckpt_path = f"{tempdir}/optimizer"
+
+        booster.save_model(model, model_ckpt_path, shard=shard)
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
         dist.barrier()
-        new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
-        check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict())
+
+        new_model = model_fn()
+        new_optimizer = Adam(new_model.parameters(), lr=0.001)
+        new_plugin = TorchDDPPlugin()
+        new_booster = Booster(plugin=new_plugin)
+        new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)
+
+        # Loading HybridAdam states to torch.Adam
+        new_booster.load_model(new_model, model_ckpt_path, strict=True)
+
+        # Add prefix to get aligned with pytorch parameter names.
+        check_state_dict_equal(
+            model.state_dict(only_rank_0=False, prefix="module.module."),
+            new_model.state_dict(),
+            ignore_device=False,
+            ignore_dtype=True,
+        )
+
+        new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+        check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(), ignore_device=False)
+
+        # Check the new model/optimizer can successfully run.
+        data = data_gen_fn()
+        data = {
+            k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
+        }
+        output = new_model(**data)
+        output = output_transform_fn(output)
+        output_key = list(output.keys())[0]
+        loss = criterion(output[output_key])
+        new_booster.backward(loss, new_optimizer)
+        new_optimizer.step()
+        new_booster.save_model(new_model, model_ckpt_path, shard=shard)
+        new_booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)
 
 
 @clear_cache_before_run()
-@parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
 @parameterize("shard", [False, True])
-@parameterize("model_name", ["transformers_llama_for_causal_lm"])
-@parameterize("size_per_shard", [32])
-@parameterize("tp_size", [1, 2])
-@parameterize("zero_size", [2])
-@parameterize("use_async", [False, True])
-def exam_state_dict(
-    placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int, use_async: bool
-):
+@parameterize("model_name", ["transformers_gpt"])
+def exam_gemini_load_from_torch(shard: bool, model_name: str):
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     criterion = lambda x: x.mean()
-    enable_flash_attention = True if tp_size > 1 else False
-    enable_fused_normalization = True if tp_size > 1 else False
-    enable_jit_fused = True if tp_size > 1 else False
-    extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
-    plugin = GeminiPlugin(
-        **placement_config,
-        precision="fp16",
-        initial_scale=(2**14),
-        tp_size=tp_size,
-        extra_dp_size=extra_dp_size,
-        enable_flash_attention=enable_flash_attention,
-        enable_fused_normalization=enable_fused_normalization,
-        enable_jit_fused=enable_jit_fused,
-    )
+    plugin = TorchDDPPlugin()
     booster = Booster(plugin=plugin)
 
     model = model_fn()
-    new_model = model_fn()
-    optimizer = HybridAdam(model.parameters(), lr=0.001)
+    optimizer = Adam(model.parameters(), lr=0.001)
     model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
-    new_optimizer = HybridAdam(new_model.parameters(), lr=0.01)
-    new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
 
     data = data_gen_fn()
     data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()}
@@ -129,32 +106,46 @@ def exam_state_dict(
     output = output_transform_fn(output)
     output_key = list(output.keys())[0]
     loss = criterion(output[output_key])
 
     booster.backward(loss, optimizer)
     optimizer.step()
-    for group in optimizer.param_groups:
-        group["lr"] = 0.1
 
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
-        if not use_async:
-            model_ckpt_path = f"{model_ckpt_path}.pt"
-        if use_async:
-            model_ckpt_path = f"{model_ckpt_path}.safetensors"
-        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
-
-        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
-        booster.checkpoint_io._sync_d2h()
-        booster.checkpoint_io._sync_io()
+
+        booster.save_model(model, model_ckpt_path, shard=shard)
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
         dist.barrier()
 
-        booster.load_model(new_model, model_ckpt_path)
+        new_model = model_fn()
+        new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
+        new_plugin = GeminiPlugin()
+        new_booster = Booster(plugin=new_plugin)
+        new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)
+
+        # Loading torch.Adam states to HybridAdam
+        new_booster.load_model(new_model, model_ckpt_path, strict=True)
+
+        # Add prefix to get aligned with pytorch parameter names.
         check_state_dict_equal(
-            model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), ignore_dtype=True
+            new_model.state_dict(only_rank_0=False, prefix="module.module."),
+            model.state_dict(),
+            ignore_device=False,
+            ignore_dtype=True,
         )
-        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-        check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False))
-        for group in new_optimizer.param_groups:
-            assert group["lr"] == 0.1
+        new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+        old_state_dict = optimizer.state_dict()
+        new_state_dict = new_optimizer.state_dict(only_rank_0=False)
+
+        # Comparison of param_groups needs special care here,
+        # since not all hyperparameters in Adam are used by HybridAdam
+        hyperparameters_to_examine = ["params", "lr", "betas", "eps", "weight_decay"]
+        for old_group, new_group in zip(old_state_dict["param_groups"], new_state_dict["param_groups"]):
+            for k in hyperparameters_to_examine:
+                assert (
+                    k in old_group and k in new_group
+                ), f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
+                assert old_group[k] == new_group[k]
+        check_state_dict_equal(old_state_dict["state"], new_state_dict["state"], ignore_device=False)
 
         # Check the new model/optimizer can successfully run.
         data = data_gen_fn()
@@ -165,41 +156,20 @@ def exam_state_dict(
         output = output_transform_fn(output)
         output_key = list(output.keys())[0]
         loss = criterion(output[output_key])
-        booster.backward(loss, new_optimizer)
+        new_booster.backward(loss, new_optimizer)
         new_optimizer.step()
-        booster.save_model(new_model, model_ckpt_path, shard=shard)
-        booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)
-
-
-def exam_lazy_from_pretrained():
-    llama_path = os.environ["LLAMA_PATH"]
-    plugin = GeminiPlugin()
-    booster = Booster(plugin=plugin)
-    orig_model = LlamaForCausalLM.from_pretrained(llama_path)
-    orig_state_dict = {k: v.half() for k, v in orig_model.state_dict().items()}
-    with LazyInitContext():
-        model = LlamaForCausalLM.from_pretrained(llama_path)
-    model, *_ = booster.boost(model)
-    with shared_tempdir() as tempdir:
-        save_path = os.path.join(tempdir, "model.pt")
-        booster.save_model(model, save_path, shard=False)
-        dist.barrier()
-        state_dict = torch.load(save_path, map_location="cpu")
-        check_state_dict_equal(state_dict, orig_state_dict, ignore_dtype=True)
+        new_booster.save_model(new_model, model_ckpt_path, shard=shard)
+        new_booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)
 
 
 def run_dist(rank, world_size, port):
     colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-    exam_state_dict()
-    # exam_state_dict_with_origin()
-    # exam_lazy_from_pretrained()
+    exam_torch_load_from_gemini()
+    exam_gemini_load_from_torch()
 
 
 @pytest.mark.dist
+@pytest.mark.parametrize("world_size", [2])
 @rerun_if_address_is_in_use()
-def test_gemini_ckpIO():
-    spawn(run_dist, 4)
-
-
-if __name__ == "__main__":
-    test_gemini_ckpIO()
+def test_gemini_ckpIO(world_size):
+    spawn(run_dist, world_size)