Add checkpointing tests #252

Open · wants to merge 3 commits into base: main
2 changes: 2 additions & 0 deletions .azure/hpu-tests.yml
@@ -115,6 +115,7 @@ jobs:
        tests/test_pytorch/test_datamodule.py \
        tests/test_pytorch/test_profiler.py \
        tests/test_pytorch/test_precision.py \
        tests/test_pytorch/test_checkpointing.py \
        tests/test_pytorch/strategies/test_hpu_parallel.py \
        tests/test_pytorch/strategies/test_hpu_ddp.py \
        --hpus 1 -W ignore::FutureWarning -m "not standalone_only" \
@@ -155,6 +156,7 @@ jobs:
      bash tests/run_standalone_tests.sh --hpus 1 -m standalone_only -f \
        tests/test_pytorch/strategies/test_hpu_parallel.py \
        tests/test_pytorch/test_precision.py \
        tests/test_pytorch/test_checkpointing.py \
        tests/test_pytorch/test_dynamic_shapes.py
    displayName: Standalone-only single card tests

6 changes: 6 additions & 0 deletions tests/conftest.py
@@ -41,3 +41,9 @@ def device_count(pytestconfig):
        return 1
    assert arg_hpus <= HPUAccelerator.auto_device_count(), "More hpu devices asked than present"
    return arg_hpus


@pytest.fixture()
def _check_distributed(device_count):
    if device_count <= 1:
        pytest.skip("Distributed test does not run on single HPU")
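
A minimal usage sketch for the new fixture (illustrative; the test name below is hypothetical, while arg_hpus is the suite's existing option fixture): a test opts in with pytest's usefixtures marker and is skipped automatically when the session runs on a single HPU.

import pytest

@pytest.mark.usefixtures("_check_distributed")
def test_needs_multiple_hpus(arg_hpus):
    # Only reached when more than one HPU was requested for the session.
    assert arg_hpus > 1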
140 changes: 138 additions & 2 deletions tests/test_pytorch/strategies/test_fsdp.py
@@ -33,6 +33,7 @@
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel

import habana_frameworks.torch.hpu as hthpu
from lightning_habana.pytorch.accelerator import HPUAccelerator
from lightning_habana.pytorch.plugins.fsdp_precision import HPUFSDPPrecision, HPUPrecisionPlugin
from lightning_habana.pytorch.strategies import HPUDDPStrategy, HPUFSDPStrategy
@@ -264,8 +265,8 @@ def test_fsdp_simple_model_activation_cp_mixed_precision(strategy, arg_hpus):


@pytest.mark.xfail(run=False, reason="To be fixed. Failure post 1.17 upgrade.")
@pytest.mark.skipif(HPUAccelerator.auto_device_count() <= 1, reason="Test requires multiple HPU devices.")
@pytest.mark.standalone()
@pytest.mark.usefixtures("_check_distributed")
def test_fsdp_strategy_simple_model_compile(tmpdir, arg_hpus):
    """Test torch.compile with a simple model using the FSDP strategy on HPU."""
    if arg_hpus <= 1:
@@ -664,7 +665,7 @@ def test_fsdp_strategy_load_optimizer_states(tmpdir, wrap_min_params, arg_hpus):
    trainer.strategy.barrier()


def test_dummy_fsdp_string_init(tmpdir):
def test_fsdp_dummy_string_init(tmpdir):
    """Test HPUFSDPStrategy initialization via a dummy strategy subclass."""

    class DummyFSDPStrategy(HPUFSDPStrategy):
@@ -806,3 +807,138 @@ def test_hpu_fsdp_strategy_device_not_hpu(tmpdir):
    )
    with pytest.raises(AssertionError, match="HPUFSDPStrategy requires HPUAccelerator"):
        trainer.fit(BoringModel())


@pytest.mark.standalone()
@pytest.mark.parametrize(
    ("ckpt", "expected_memory"),
    [
        (True, 5679.0),
        (False, 5674.25),
    ],
)
def test_hpu_fsdp_activation_checkpointing_memory_usage(tmpdir, ckpt, expected_memory):
    """Test memory usage difference with and without activation checkpointing."""

    class TestMemoryModel(TestFSDPModel):
        def _init_model(self) -> None:
            self.layer = torch.nn.Sequential(
                torch.nn.Linear(32, 32),
                torch.nn.Linear(32, 32),
                torch.nn.Linear(32, 2),
            )
            # Number of activations for a Linear layer: out_features * batch_size(32)
            # https://discuss.pytorch.org/t/number-of-activations-for-linear-and-conv2d-layer-comparison/48528/2
            # Activation memory without checkpointing: (32 + 32 + 2) * 32 * 4 bytes = 8.25KB
            # Activation memory with checkpointing: (32 + 2) * 32 * 4 bytes = 4.25KB
            # Expected savings: 8.25KB - 4.25KB = 4KB, close to the observed 5679.0KB - 5674.25KB = 4.75KB.
            # These are estimates; the device may hold other allocations.
            self.peak_memory = 0
            self.current_step = 0

        def on_train_batch_start(self, batch, batch_idx):
            # Start measuring at step 1; step 0 may include one-time warm-up
            # allocations that would skew the peak-memory reading.
            if self.current_step == 1:
                hthpu.reset_peak_memory_stats()

        def on_train_batch_end(self, outputs, batch, batch_idx):
            self.current_step += 1
            if self.current_step <= 1:
                return
            self.peak_memory = hthpu.max_memory_allocated() / 1024

    seed_everything(42)
    model = TestMemoryModel()
    dm = BoringDataModule()
    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator=HPUAccelerator(),
        devices=1,
        strategy=HPUFSDPStrategy(
            parallel_devices=[torch.device("hpu")],
            auto_wrap_policy={nn.Linear} if ckpt else None,
            activation_checkpointing_policy={nn.Linear} if ckpt else None,
        ),
        max_steps=2,
    )
    trainer.fit(model, dm)
    # A 2KB absolute tolerance still separates the two cases (4.75KB apart); with
    # rtol=1 the tolerance would exceed the values and the assert would always pass.
    assert torch.allclose(torch.tensor(model.peak_memory), torch.tensor(expected_memory), atol=2.0, rtol=0.0)
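
For reference, the peak-memory pattern the callbacks above rely on, as a standalone sketch (assumes a machine with the habana_frameworks stack; run_step is a placeholder for any HPU workload):

import habana_frameworks.torch.hpu as hthpu

def measure_peak_kb(run_step):
    hthpu.reset_peak_memory_stats()  # clear the device's high-water mark
    run_step()  # execute the workload being profiled
    return hthpu.max_memory_allocated() / 1024  # peak bytes -> KB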


def test_hpu_fsdp_gradient_computation(tmpdir):
    """Test that gradients are computed correctly with checkpointing."""
    grads = {}
    for ckpt in [True, False]:
        seed_everything(42)
        model = TestFSDPModel()
        dm = BoringDataModule()
        trainer = Trainer(
            default_root_dir=tmpdir,
            accelerator=HPUAccelerator(),
            devices=1,
            strategy=HPUFSDPStrategy(
                parallel_devices=[torch.device("hpu")],
                auto_wrap_policy={nn.Linear} if ckpt else None,
                activation_checkpointing_policy={nn.Linear} if ckpt else None,
            ),
            max_steps=1,
        )
        trainer.fit(model, dm)
        _grads = {}
        for name, param in model.named_parameters():
            name = name.replace("._fsdp_wrapped_module._checkpoint_wrapped_module", "")
            if param.grad is not None:
                _grads[name] = param.grad.mean().item()
        grads[f"{ckpt=}"] = _grads
    assert grads["ckpt=True"].keys() == grads["ckpt=False"].keys()
    for key in grads["ckpt=True"]:
        assert grads["ckpt=True"][key] == grads["ckpt=False"][key]
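
FSDP and activation-checkpointing wrappers inject extra segments into parameter names, which is why the loop above normalizes names before comparing the two runs. A small sketch of that normalization (the raw name is illustrative):

raw = "layer._fsdp_wrapped_module._checkpoint_wrapped_module.0.weight"
clean = raw.replace("._fsdp_wrapped_module._checkpoint_wrapped_module", "")
assert clean == "layer.0.weight"  # names now match across wrapped and unwrapped runs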


@pytest.mark.standalone()
@pytest.mark.usefixtures("_check_distributed")
def test_hpu_fsdp_dist_checkpoint_save(tmpdir):
    model = TestFSDPModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator=HPUAccelerator(),
        devices=2,
        strategy=HPUFSDPStrategy(parallel_devices=[torch.device("hpu")] * 2, state_dict_type="sharded"),
        max_steps=1,
    )
    trainer.fit(model)

    if trainer.global_rank == 0:
        checkpoint_dir = os.path.join(tmpdir, "lightning_logs", "version_0", "checkpoints", "epoch=0-step=1.ckpt")
        for rank in range(2):
            assert os.path.isfile(os.path.join(checkpoint_dir, f"__{rank}_0.distcp"))
            assert os.path.getsize(os.path.join(checkpoint_dir, f"__{rank}_0.distcp")) > 0
    trainer.strategy.barrier()


@pytest.mark.standalone()
@pytest.mark.usefixtures("_check_distributed")
def test_hpu_fsdp_dist_checkpoint_load(tmpdir):
    model = TestFSDPModel()

    # Save ckpts
    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator=HPUAccelerator(),
        devices=2,
        strategy=HPUFSDPStrategy(parallel_devices=[torch.device("hpu")] * 2, state_dict_type="sharded"),
        max_steps=1,
    )
    trainer.fit(model)

    # Load and resume training from the ckpt
    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator=HPUAccelerator(),
        devices=2,
        strategy=HPUFSDPStrategy(parallel_devices=[torch.device("hpu")] * 2, state_dict_type="sharded"),
        max_steps=1,
    )
    trainer.fit(
        model, ckpt_path=os.path.join(tmpdir, "lightning_logs", "version_0", "checkpoints", "epoch=0-step=1.ckpt")
    )
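
As the save test asserts, a checkpoint written with state_dict_type="sharded" is a directory (still named *.ckpt) containing one .distcp file per rank. A sketch of verifying the shards before resuming (paths taken from the tests above):

import os

ckpt_dir = os.path.join("lightning_logs", "version_0", "checkpoints", "epoch=0-step=1.ckpt")
shards = [os.path.join(ckpt_dir, f"__{rank}_0.distcp") for rank in range(2)]  # one shard per rank
assert all(os.path.isfile(s) and os.path.getsize(s) > 0 for s in shards)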