This repository has been archived by the owner on Feb 1, 2024. It is now read-only.

bump Lightning 2.1+ #54

Open
wants to merge 6 commits into base: main
1 change: 0 additions & 1 deletion requirements/lightning.txt
@@ -1,4 +1,3 @@
# these are the requirements if you go with the main lightning package

# in 2.0.7 we have removed lightning.pytorch.overrides.base._LightningPrecisionModuleWrapperBase
lightning >=2.0.0, <2.2.0
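
The pin above reflects the fact that recent Lightning releases dropped `lightning.pytorch.overrides.base._LightningPrecisionModuleWrapperBase`, which this plugin previously imported (see the utils.py change below). As a minimal sketch under that assumption (not code from this PR), version-spanning code can probe for the symbol instead of hard-coding a version bound:

# Illustrative compatibility probe (assumption, not part of the PR): detect whether the
# installed Lightning still ships _LightningPrecisionModuleWrapperBase.
try:
    from lightning.pytorch.overrides.base import _LightningPrecisionModuleWrapperBase  # noqa: F401
    HAS_PRECISION_WRAPPER = True
except ImportError:
    HAS_PRECISION_WRAPPER = False

print(f"_LightningPrecisionModuleWrapperBase available: {HAS_PRECISION_WRAPPER}")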
4 changes: 1 addition & 3 deletions src/lightning_graphcore/strategy.py
@@ -228,9 +228,7 @@ def _convert_to_poptorch_loader(
# the user is returning the `poptorch.DataLoader` directly, don't change anything.
return dataloader

-dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(
-    dataloader, sampler, mode, self.replication_factor > 1
-)
+dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode)
opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
return _reinstantiate_wrapped_cls(dataloader, opts, *dl_args, explicit_cls=poptorch.DataLoader, **dl_kwargs)
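
The hunk above is where the strategy rebuilds the user's dataloader as a `poptorch.DataLoader` driven by the strategy's training or inference `poptorch.Options`; the PR only drops the extra argument previously passed to `_get_dataloader_init_args_and_kwargs`. For orientation, a self-contained sketch of that kind of wrapping (the toy dataset and option values are assumptions, not code from this repository):

# Standalone sketch: building a poptorch.DataLoader by hand.
import poptorch
import torch
from torch.utils.data import TensorDataset

dataset = TensorDataset(torch.rand(64, 32), torch.rand(64, 1))

opts = poptorch.Options()   # stands in for the strategy's training_opts / inference_opts
opts.deviceIterations(4)    # number of batches pushed to the IPU per invocation

# poptorch.DataLoader takes the Options object first, then the usual DataLoader arguments.
loader = poptorch.DataLoader(opts, dataset, batch_size=8, shuffle=True, num_workers=2)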

4 changes: 1 addition & 3 deletions src/lightning_graphcore/utils.py
@@ -20,15 +20,13 @@
if package_available("lightning"):
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.pytorch import LightningModule
-from lightning.pytorch.overrides.base import _LightningPrecisionModuleWrapperBase
elif package_available("pytorch_lightning"):
from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from pytorch_lightning import LightningModule
-from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase


class _LightningModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):
-def __init__(self, forward_module: Union[LightningModule, _LightningPrecisionModuleWrapperBase]) -> None:
+def __init__(self, forward_module: LightningModule) -> None:
"""Wrap the user's LightningModule and redirect the forward call to the appropriate `*_step()` methods.

Inheriting classes may also modify the inputs or outputs of forward.
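
The docstring describes a forward redirect: calls to the wrapper's forward are routed to the wrapped module's `training_step`, `validation_step`, `test_step`, or `predict_step` depending on the trainer's stage. A condensed sketch of that dispatch (my own illustration, assuming the stage is read from the attached trainer; the real lightning_graphcore module may differ in detail):

# Rough sketch of the forward redirect, not the actual lightning_graphcore implementation.
import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.trainer.states import RunningStage


class _ForwardRedirectSketch(torch.nn.Module):
    def __init__(self, forward_module: LightningModule) -> None:
        super().__init__()
        self.module = forward_module

    def forward(self, *args, **kwargs):
        # Route the call based on the trainer's current stage.
        stage = self.module.trainer.state.stage
        if stage == RunningStage.TRAINING:
            return self.module.training_step(*args, **kwargs)
        if stage == RunningStage.VALIDATING:
            return self.module.validation_step(*args, **kwargs)
        if stage == RunningStage.TESTING:
            return self.module.test_step(*args, **kwargs)
        if stage == RunningStage.PREDICTING:
            return self.module.predict_step(*args, **kwargs)
        return self.module(*args, **kwargs)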
21 changes: 2 additions & 19 deletions tests/test_accelerator.py
@@ -54,15 +54,14 @@ def test_fail_if_no_ipus(_, tmpdir): # noqa: PT019
Trainer(default_root_dir=tmpdir, accelerator=IPUAccelerator(), devices=1)


-@pytest.mark.xfail() # todo
def test_accelerator_selected(tmpdir):
assert IPUAccelerator.is_available()
trainer = Trainer(default_root_dir=tmpdir, accelerator="ipu", devices=1)
assert isinstance(trainer.accelerator, IPUAccelerator)


def test_warning_if_ipus_not_used():
-with pytest.warns(UserWarning, match="IPU available but not used. Set `accelerator` and `devices`"):
+with pytest.warns(UserWarning):
Trainer(accelerator="cpu")
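
Dropping `match` above means the test now accepts any `UserWarning`, not just the IPU-specific message. A small standalone illustration of the difference (the warning text is borrowed from the removed line; this is not part of the PR):

# Standalone illustration: pytest.warns with and without `match`.
import warnings

import pytest


def emit_ipu_warning():
    warnings.warn("IPU available but not used. Set `accelerator` and `devices`", UserWarning)


def test_any_user_warning():
    with pytest.warns(UserWarning):  # passes for any UserWarning, regardless of message
        emit_ipu_warning()


def test_specific_message():
    with pytest.warns(UserWarning, match="IPU available but not used"):  # message must match this regex
        emit_ipu_warning()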


@@ -82,18 +81,7 @@ def test_all_stages(tmpdir, devices):
trainer.predict(model)


-@pytest.mark.parametrize(
-    "devices",
-    [
-        1,
-        pytest.param(
-            4,
-            marks=pytest.mark.xfail(  # fixme
-                AssertionError, reason="Invalid batch dimension: In the input torch.Size([1, 32]), ..."
-            ),
-        ),
-    ],
-)
+@pytest.mark.parametrize("devices", [1, 4])
def test_inference_only(tmpdir, devices):
model = IPUModel()

@@ -341,7 +329,6 @@ def test_clip_gradients_fails(tmpdir):
trainer.fit(model)


-@pytest.mark.xfail(RuntimeError, reason="element 0 of tensors does not require grad and does not have ...") # todo
def test_autoreport(tmpdir):
"""Ensure autoreport dumps to a file."""
model = IPUModel()
@@ -358,7 +345,6 @@ def test_autoreport(tmpdir):
assert os.path.isfile(autoreport_path + "training/profile.pop")


-@pytest.mark.xfail(RuntimeError, reason="element 0 of tensors does not require grad and does not have ...") # todo
def test_manual_poptorch_dataloader(tmpdir):
model_options = poptorch.Options()

@@ -390,7 +376,6 @@ def train_dataloader(self):
assert dataloader.drop_last # was kept


-@pytest.mark.xfail(RuntimeError, reason="element 0 of tensors does not require grad and does not have ...") # todo
def test_manual_poptorch_opts(tmpdir):
"""Ensure if the user passes manual poptorch Options, we run with the correct object."""
model = IPUModel()
@@ -573,7 +558,6 @@ def test_accelerator_ipu_with_devices():
assert trainer.num_devices == 8


-@pytest.mark.xfail(AssertionError, reason="not implemented on PL side")
def test_accelerator_auto_with_devices_ipu():
trainer = Trainer(accelerator="auto", devices=8)
assert isinstance(trainer.accelerator, IPUAccelerator)
@@ -618,7 +602,6 @@ def test_poptorch_models_at_different_stages(tmpdir):
assert list(trainer.strategy.poptorch_models) == [stage]


-@pytest.mark.xfail(AssertionError, reason="not implemented on PL side")
def test_devices_auto_choice_ipu():
trainer = Trainer(accelerator="auto", devices="auto")
assert trainer.num_devices == 4