This repository has been archived by the owner on Feb 1, 2024. It is now read-only.

bump Lightning 2.1+ #54

Open · wants to merge 6 commits into main

Changes from 2 commits
4 changes: 2 additions & 2 deletions requirements/lightning.txt
@@ -1,4 +1,4 @@
 # this sets the requirements contains if you go with main lightning
 
-# in 2.0.7 we have removed lightning.pytorch.overrides.base._LightningPrecisionModuleWrapperBase
-lightning >=2.0.0, <=2.0.6
+
+lightning >=2.1.0
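
A quick way to sanity-check the new floor locally, as a minimal sketch (it assumes the `lightning` distribution and the `packaging` helper are installed; nothing below is part of the PR itself):

# Minimal sketch: confirm the installed lightning satisfies the new ">=2.1.0" pin.
from importlib.metadata import version

from packaging.version import Version

assert Version(version("lightning")) >= Version("2.1.0"), "lightning must be >= 2.1.0"
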
2 changes: 1 addition & 1 deletion src/lightning_graphcore/strategy.py
@@ -229,7 +229,7 @@ def _convert_to_poptorch_loader(
             return dataloader
 
         dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(
-            dataloader, sampler, mode, self.replication_factor > 1
+            dataloader, sampler, mode
         )
         opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
         return _reinstantiate_wrapped_cls(dataloader, opts, *dl_args, explicit_cls=poptorch.DataLoader, **dl_kwargs)
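
For context on what the conversion above amounts to, here is a rough sketch of re-instantiating a plain DataLoader as a `poptorch.DataLoader`; `opts`, the dataset, and the batch size are illustrative stand-ins, and only `poptorch.Options` and `poptorch.DataLoader` are real Poptorch APIs:

import poptorch
import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative stand-ins for the strategy's options and the user's loader.
opts = poptorch.Options()  # e.g. configured with replication / device iterations
dataset = TensorDataset(torch.randn(64, 32))
plain_loader = DataLoader(dataset, batch_size=4, shuffle=True)

# The conversion recovers the original DataLoader's constructor arguments and
# rebuilds the loader as a poptorch.DataLoader bound to the strategy's Options,
# so host-side batching matches the IPU configuration.
ipu_loader = poptorch.DataLoader(
    opts,
    plain_loader.dataset,
    batch_size=plain_loader.batch_size,
    shuffle=True,  # assumption: the original shuffle setting is carried over
)
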
4 changes: 1 addition & 3 deletions src/lightning_graphcore/utils.py
@@ -20,15 +20,13 @@
 if package_available("lightning"):
     from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
     from lightning.pytorch import LightningModule
-    from lightning.pytorch.overrides.base import _LightningPrecisionModuleWrapperBase
 elif package_available("pytorch_lightning"):
     from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
     from pytorch_lightning import LightningModule
-    from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
 
 
 class _LightningModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):
-    def __init__(self, forward_module: Union[LightningModule, _LightningPrecisionModuleWrapperBase]) -> None:
+    def __init__(self, forward_module: LightningModule) -> None:
         """Wrap the user's LightningModule and redirect the forward call to the appropriate `*_step()` methods.
 
         Inheriting classes may also modify the inputs or outputs of forward.
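
The docstring above describes a forward-redirect pattern; the sketch below illustrates that idea with the simplified 2.1 signature. The class name and body are illustrative only, not the module's actual implementation:

import torch
from lightning.pytorch import LightningModule


class _ForwardRedirectSketch(torch.nn.Module):
    """Illustrative only: dispatch forward() to the wrapped LightningModule's
    *_step method that matches the trainer's current stage."""

    def __init__(self, forward_module: LightningModule) -> None:
        super().__init__()
        self.module = forward_module

    def forward(self, *args, **kwargs):
        trainer = self.module.trainer
        if trainer.training:
            return self.module.training_step(*args, **kwargs)
        if trainer.sanity_checking or trainer.validating:
            return self.module.validation_step(*args, **kwargs)
        if trainer.testing:
            return self.module.test_step(*args, **kwargs)
        if trainer.predicting:
            return self.module.predict_step(*args, **kwargs)
        return self.module(*args, **kwargs)
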
26 changes: 3 additions & 23 deletions tests/test_accelerator.py
@@ -54,15 +54,14 @@ def test_fail_if_no_ipus(_, tmpdir):  # noqa: PT019
         Trainer(default_root_dir=tmpdir, accelerator=IPUAccelerator(), devices=1)
 
 
-@pytest.mark.xfail()  # todo
 def test_accelerator_selected(tmpdir):
     assert IPUAccelerator.is_available()
     trainer = Trainer(default_root_dir=tmpdir, accelerator="ipu", devices=1)
     assert isinstance(trainer.accelerator, IPUAccelerator)
 
 
 def test_warning_if_ipus_not_used():
-    with pytest.warns(UserWarning, match="IPU available but not used. Set `accelerator` and `devices`"):
+    with pytest.warns(UserWarning):
         Trainer(accelerator="cpu")

@@ -72,10 +71,7 @@ def test_no_warning_strategy(tmpdir):
     assert len(record) == 0
 
 
-@pytest.mark.parametrize(
-    "devices",
-    [1, 4],
-)
+@pytest.mark.parametrize("devices",[1, 4])
 def test_all_stages(tmpdir, devices):
     model = IPUModel()
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy=IPUStrategy(), devices=devices)
@@ -85,18 +81,7 @@ def test_all_stages(tmpdir, devices):
     trainer.predict(model)
 
 
-@pytest.mark.parametrize(
-    "devices",
-    [
-        1,
-        pytest.param(
-            4,
-            marks=pytest.mark.xfail(  # fixme
-                AssertionError, reason="Invalid batch dimension: In the input torch.Size([1, 32]), ..."
-            ),
-        ),
-    ],
-)
+@pytest.mark.parametrize("devices",[1, 4])
 def test_inference_only(tmpdir, devices):
     model = IPUModel()
 
@@ -344,7 +329,6 @@ def test_clip_gradients_fails(tmpdir):
     trainer.fit(model)
 
 
-@pytest.mark.xfail(RuntimeError, reason="element 0 of tensors does not require grad and does not have ...")  # todo
 def test_autoreport(tmpdir):
     """Ensure autoreport dumps to a file."""
     model = IPUModel()
@@ -361,7 +345,6 @@ def test_autoreport(tmpdir):
     assert os.path.isfile(autoreport_path + "training/profile.pop")
 
 
-@pytest.mark.xfail(RuntimeError, reason="element 0 of tensors does not require grad and does not have ...")  # todo
 def test_manual_poptorch_dataloader(tmpdir):
     model_options = poptorch.Options()
 
@@ -393,7 +376,6 @@ def train_dataloader(self):
     assert dataloader.drop_last  # was kept
 
 
-@pytest.mark.xfail(RuntimeError, reason="element 0 of tensors does not require grad and does not have ...")  # todo
 def test_manual_poptorch_opts(tmpdir):
     """Ensure if the user passes manual poptorch Options, we run with the correct object."""
     model = IPUModel()
@@ -576,7 +558,6 @@ def test_accelerator_ipu_with_devices():
     assert trainer.num_devices == 8
 
 
-@pytest.mark.xfail(AssertionError, reason="not implemented on PL side")
 def test_accelerator_auto_with_devices_ipu():
     trainer = Trainer(accelerator="auto", devices=8)
     assert isinstance(trainer.accelerator, IPUAccelerator)
@@ -621,7 +602,6 @@ def test_poptorch_models_at_different_stages(tmpdir):
         assert list(trainer.strategy.poptorch_models) == [stage]
 
 
-@pytest.mark.xfail(AssertionError, reason="not implemented on PL side")
 def test_devices_auto_choice_ipu():
     trainer = Trainer(accelerator="auto", devices="auto")
     assert trainer.num_devices == 4