Make test generic
ankitgola005 committed Nov 5, 2024
1 parent f75b138 commit e0c826a
Showing 1 changed file with 36 additions and 28 deletions.
tests/test_pytorch/strategies/test_fsdp.py: 36 additions, 28 deletions
@@ -811,25 +811,17 @@ def test_hpu_fsdp_strategy_device_not_hpu(tmpdir):

 @pytest.mark.standalone()
 @pytest.mark.usefixtures("_check_distributed")
 @pytest.mark.skipif(get_device_name_from_hlsmi() == "GAUDI", reason="Requires Gaudi2 and above.")
-@pytest.mark.parametrize(
-    ("strategy", "expected_memory"),
-    [
-        ("NO_SHARD", 2.2),
-        ("SHARD_GRAD_OP", 1.76),
-        ("FULL_SHARD", 1.59),
-    ],
-)
-def test_fsdp_sharding_strategy_memory(strategy, expected_memory):
-    """Test FSDP memory with sharding strategies."""
-    seed_everything(42)
+def test_fsdp_sharding_strategy_memory_comparison():
+    """Test that FSDP memory usage follows expected pattern: FULL_SHARD <= SHARD_GRAD_OP <= NO_SHARD."""

-    class TestBoringModelMemory(TestBoringModel):
+    class MemoryMonitorModule(BoringModel):
+        """Module to monitor memory usage."""
+
         def __init__(self):
             super().__init__()
-            self.layer = torch.nn.Sequential(torch.nn.Linear(32, 512000), torch.nn.ReLU(), torch.nn.Linear(512000, 2))
             self.memory = None
             self.current_step = 0
+            self.layer = torch.nn.Sequential(torch.nn.Linear(32, 1024000), torch.nn.ReLU(), torch.nn.Linear(1024000, 2))

         def on_train_batch_start(self, batch, batch_idx):
             htorch.hpu.reset_peak_memory_stats()
@@ -843,18 +835,34 @@ def on_train_batch_end(self, outputs, batch, batch_idx):
                 torch.tensor(self.memory, device=torch.device("hpu")), reduce_op="sum"
             )

-    model = TestBoringModelMemory()
-    trainer = Trainer(
-        accelerator=HPUAccelerator(),
-        devices=2,
-        strategy=HPUFSDPStrategy(
-            parallel_devices=[torch.device("hpu")] * 2,
-            sharding_strategy=strategy,
-            precision_plugin=HPUFSDPPrecision("bf16-mixed"),
-        ),
-        max_steps=2,
+    def measure_memory_usage(strategy_name):
+        """Measure memory for a given sharding strategy."""
+        seed_everything(42)
+        model = MemoryMonitorModule()
+        trainer = Trainer(
+            accelerator=HPUAccelerator(),
+            devices=2,
+            strategy=HPUFSDPStrategy(
+                parallel_devices=[torch.device("hpu")] * 2,
+                sharding_strategy=strategy_name,
+                precision_plugin=HPUFSDPPrecision("bf16-mixed"),
+            ),
+            max_steps=2,
+            limit_val_batches=0,
+        )
+
+        trainer.fit(model)
+        return model.memory
+
+    memory_usage = {
+        strategy: measure_memory_usage(strategy) for strategy in ["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD"]
+    }
+
+    assert memory_usage["FULL_SHARD"] <= memory_usage["SHARD_GRAD_OP"], (
+        f"FULL_SHARD memory ({memory_usage['FULL_SHARD']:.2f} GB) should be less than or equal to "
+        f"SHARD_GRAD_OP memory ({memory_usage['SHARD_GRAD_OP']:.2f} GB)"
+    )
+    assert memory_usage["SHARD_GRAD_OP"] <= memory_usage["NO_SHARD"], (
+        f"SHARD_GRAD_OP memory ({memory_usage['SHARD_GRAD_OP']:.2f} GB) should be less than or equal to "
+        f"NO_SHARD memory ({memory_usage['NO_SHARD']:.2f} GB)"
+    )

-    trainer.fit(model)
-    print(f"{model.memory=}")
-    assert torch.allclose(model.memory, torch.tensor(expected_memory), atol=0.1, rtol=0.1)
