Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace _init_optim_state w/ tnt's util #802

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions tests/framework/callbacks/test_module_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@

from torchtnt.framework.callbacks.module_summary import ModuleSummary
from torchtnt.framework.state import EntryPoint, PhaseState, State
from torchtnt.utils.version import is_torch_version_geq_1_13

MODULE_SUMMARY_FLOPS_AVAILABLE = False
if is_torch_version_geq_1_13():
MODULE_SUMMARY_FLOPS_AVAILABLE = True


class ModuleSummaryTest(unittest.TestCase):
Expand Down Expand Up @@ -85,10 +80,6 @@ def forward(self, x):
self.assertTrue("b1" in ms.submodule_summaries)
self.assertTrue("l2" in ms.submodule_summaries)

@unittest.skipUnless(
condition=MODULE_SUMMARY_FLOPS_AVAILABLE,
reason="This test needs PyTorch 1.13 or greater to run.",
)
def test_module_summary_retrieve_module_summaries_module_inputs(self) -> None:
"""
Test ModuleSummary callback in train
Expand Down
11 changes: 2 additions & 9 deletions tests/framework/test_auto_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,6 @@
from unittest.mock import MagicMock, patch

import torch
from torchtnt.framework.auto_unit import TrainStepResults
from torchtnt.utils.test_utils import skip_if_not_distributed

from torchtnt.utils.version import is_torch_version_geq_1_13

COMPILE_AVAIL = False
if is_torch_version_geq_1_13():
COMPILE_AVAIL = True
import torch._dynamo

from pyre_extensions import none_throws, ParameterSpecification as ParamSpec

Expand All @@ -37,6 +28,7 @@
AutoUnit,
SWALRParams,
SWAParams,
TrainStepResults,
)
from torchtnt.framework.evaluate import evaluate
from torchtnt.framework.predict import predict
Expand All @@ -49,6 +41,7 @@
from torchtnt.utils.lr_scheduler import TLRScheduler
from torchtnt.utils.prepare_module import DDPStrategy
from torchtnt.utils.progress import Progress
from torchtnt.utils.test_utils import skip_if_not_distributed
from torchtnt.utils.timer import Timer

TParams = ParamSpec("TParams")
Expand Down
19 changes: 4 additions & 15 deletions tests/framework/test_auto_unit_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,16 @@
# pyre-strict

import unittest

from copy import deepcopy
from typing import TypeVar
from unittest.mock import MagicMock, patch

import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu

from torchtnt.utils.version import is_torch_version_geq_1_13

COMPILE_AVAIL = False
if is_torch_version_geq_1_13():
COMPILE_AVAIL = True
import torch._dynamo

from copy import deepcopy

from pyre_extensions import ParameterSpecification as ParamSpec
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torchtnt.framework._test_utils import (
DummyAutoUnit,
generate_random_dataloader,
Expand All @@ -40,6 +32,7 @@
from torchtnt.utils.distributed import spawn_multi_process
from torchtnt.utils.env import init_from_env, seed
from torchtnt.utils.prepare_module import DDPStrategy, FSDPStrategy, TorchCompileParams
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu

TParams = ParamSpec("TParams")
T = TypeVar("T")
Expand Down Expand Up @@ -320,10 +313,6 @@ def test_predict_mixed_precision_fp16(self, mock_autocast: MagicMock) -> None:
device_type="cuda", dtype=torch.float16, enabled=True
)

@unittest.skipUnless(
condition=COMPILE_AVAIL,
reason="This test needs PyTorch 1.13 or greater to run.",
)
@skip_if_not_gpu
@patch("torch.compile")
def test_compile_predict(self, mock_dynamo: MagicMock) -> None:
Expand Down
8 changes: 0 additions & 8 deletions tests/utils/test_memory_snapshot_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,9 @@
MemorySnapshotParams,
MemorySnapshotProfiler,
)
from torchtnt.utils.version import is_torch_version_geq_2_0


class MemorySnapshotProfilerTest(unittest.TestCase):

torch_version_geq_2_0: bool = is_torch_version_geq_2_0()

@unittest.skipUnless(
condition=torch_version_geq_2_0,
reason="This test needs changes from PyTorch 2.0 to run.",
)
def test_validation(self) -> None:
"""Test parameter validation."""
with tempfile.TemporaryDirectory() as temp_dir:
Expand Down
8 changes: 0 additions & 8 deletions tests/utils/test_memory_snapshot_profiler_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,10 @@
MemorySnapshotProfiler,
)
from torchtnt.utils.test_utils import skip_if_not_gpu
from torchtnt.utils.version import is_torch_version_geq_2_0


class MemorySnapshotProfilerGPUTest(unittest.TestCase):

torch_version_geq_2_0: bool = is_torch_version_geq_2_0()

@skip_if_not_gpu
@unittest.skipUnless(
condition=torch_version_geq_2_0,
reason="This test needs changes from PyTorch 2.0 to run.",
)
def test_stop_step(self) -> None:
"""Test that a memory snapshot is saved when stop_step is reached."""
with tempfile.TemporaryDirectory() as temp_dir:
Expand Down
5 changes: 0 additions & 5 deletions tests/utils/test_oom_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,10 @@
from torchtnt.utils.oom import log_memory_snapshot

from torchtnt.utils.test_utils import skip_if_not_gpu
from torchtnt.utils.version import is_torch_version_geq_2_0


class OomGPUTest(unittest.TestCase):
@skip_if_not_gpu
@unittest.skipUnless(
condition=bool(is_torch_version_geq_2_0()),
reason="This test needs changes from PyTorch 2.0 to run.",
)
def test_log_memory_snapshot(self) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
# Record history
Expand Down
27 changes: 6 additions & 21 deletions tests/utils/test_prepare_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,12 @@
TorchCompileParams,
)
from torchtnt.utils.test_utils import skip_if_not_distributed
from torchtnt.utils.version import is_torch_version_geq_1_13, Version

COMPILE_AVAIL = False
if is_torch_version_geq_1_13():
COMPILE_AVAIL = True
import torch._dynamo
from torchtnt.utils.version import is_torch_version_geq


class PrepareModelTest(unittest.TestCase):
torch_version_geq_2_1_0: bool = is_torch_version_geq("2.1.0")

def test_invalid_fsdp_strategy_str_values(self) -> None:
from torchtnt.utils.prepare_module import MixedPrecision as _MixedPrecision

Expand Down Expand Up @@ -149,7 +146,7 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> No

tc = unittest.TestCase()
with patch(
"torchtnt.utils.version.get_torch_version", return_value=Version("2.0.0")
"torchtnt.utils.prepare_module.is_torch_version_geq", return_value=False
):
with tc.assertRaisesRegex(
RuntimeError,
Expand All @@ -162,18 +159,6 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> No
torch_compile_params=TorchCompileParams(backend="inductor"),
)

# no error should be thrown on latest pytorch
prepare_module(
module=torch.nn.Linear(2, 2),
device=init_from_env(),
strategy=DDPStrategy(static_graph=True),
torch_compile_params=TorchCompileParams(backend="inductor"),
)

@unittest.skipUnless(
condition=COMPILE_AVAIL,
reason="This test needs PyTorch 1.13 or greater to run.",
)
def test_prepare_module_compile_invalid_backend(self) -> None:
"""
verify error is thrown on invalid backend
Expand All @@ -200,8 +185,8 @@ def test_prepare_module_incompatible_FSDP_torchcompile_params(self) -> None:
)

@unittest.skipUnless(
condition=COMPILE_AVAIL,
reason="This test needs PyTorch 1.13 or greater to run.",
torch_version_geq_2_1_0,
reason="Must be on torch 2.1.0+ to run test",
)
def test_prepare_module_compile_module_state_dict(self) -> None:
device = init_from_env()
Expand Down
44 changes: 4 additions & 40 deletions tests/utils/test_prepare_module_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@

# pyre-strict
import unittest
from unittest.mock import patch

import torch

from torch.distributed._composable import fully_shard
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from torch.nn.parallel import DistributedDataParallel as DDP
Expand All @@ -24,15 +25,6 @@
prepare_module,
)
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
from torchtnt.utils.version import is_torch_version_geq_1_13, is_torch_version_geq_2_0

COMPILE_AVAIL = False
if is_torch_version_geq_1_13():
COMPILE_AVAIL = True
import torch._dynamo

if is_torch_version_geq_2_0():
from torch.distributed._composable import fully_shard


class PrepareModelGPUTest(unittest.TestCase):
Expand Down Expand Up @@ -81,33 +73,6 @@ def _test_prepare_fsdp() -> None:
tc = unittest.TestCase()
tc.assertTrue(isinstance(fsdp_module, FSDP))

@skip_if_not_distributed
@skip_if_not_gpu
def test_fsdp_pytorch_version(self) -> None:
"""
Test that a RuntimeError is thrown when using FSDP, and PyTorch < v1.12
"""
spawn_multi_process(
2,
"nccl",
self._test_fsdp_pytorch_version,
)

@staticmethod
def _test_fsdp_pytorch_version() -> None:
device = init_from_env()
module = torch.nn.Linear(2, 2).to(device)

tc = unittest.TestCase()
with patch(
"torchtnt.utils.prepare_module.is_torch_version_geq_1_12",
return_value=False,
), tc.assertRaisesRegex(
RuntimeError,
"Please install PyTorch 1.12 or higher to use FSDP: https://pytorch.org/get-started/locally/",
):
_ = prepare_fsdp(module, device, FSDPStrategy())

@skip_if_not_distributed
@unittest.skipUnless(
condition=bool(torch.cuda.device_count() >= 2),
Expand All @@ -128,9 +93,8 @@ def _test_is_fsdp_module() -> None:
model = FSDP(torch.nn.Linear(1, 1, device=device))
assert _is_fsdp_module(model)
model = torch.nn.Linear(1, 1, device=device)
if is_torch_version_geq_2_0():
fully_shard(model)
assert _is_fsdp_module(model)
fully_shard(model)
assert _is_fsdp_module(model)

@skip_if_not_distributed
@skip_if_not_gpu
Expand Down
45 changes: 1 addition & 44 deletions tests/utils/test_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,48 +48,5 @@ def test_get_torch_version(self) -> None:
self.assertEqual(version.get_torch_version(), Version("1.12.0"))

def test_torch_version_comparators(self) -> None:
with patch.object(torch, "__version__", "1.7.0"):
self.assertFalse(version.is_torch_version_geq_1_8())
self.assertFalse(version.is_torch_version_geq_1_9())
self.assertFalse(version.is_torch_version_geq_1_10())
self.assertFalse(version.is_torch_version_geq_1_11())
self.assertFalse(version.is_torch_version_geq_1_12())

with patch.object(torch, "__version__", "1.8.0"):
self.assertTrue(version.is_torch_version_geq_1_8())
self.assertFalse(version.is_torch_version_geq_1_9())
self.assertFalse(version.is_torch_version_geq_1_10())
self.assertFalse(version.is_torch_version_geq_1_11())
self.assertFalse(version.is_torch_version_geq_1_12())

with patch.object(torch, "__version__", "1.9.0"):
self.assertTrue(version.is_torch_version_geq_1_8())
self.assertTrue(version.is_torch_version_geq_1_9())
self.assertFalse(version.is_torch_version_geq_1_10())
self.assertFalse(version.is_torch_version_geq_1_11())
self.assertFalse(version.is_torch_version_geq_1_12())

with patch.object(torch, "__version__", "1.10.0"):
self.assertTrue(version.is_torch_version_geq_1_8())
self.assertTrue(version.is_torch_version_geq_1_9())
self.assertTrue(version.is_torch_version_geq_1_10())
self.assertFalse(version.is_torch_version_geq_1_11())
self.assertFalse(version.is_torch_version_geq_1_12())

with patch.object(torch, "__version__", "1.11.0"):
self.assertTrue(version.is_torch_version_geq_1_8())
self.assertTrue(version.is_torch_version_geq_1_9())
self.assertTrue(version.is_torch_version_geq_1_10())
self.assertTrue(version.is_torch_version_geq_1_11())
self.assertFalse(version.is_torch_version_geq_1_12())

with patch.object(torch, "__version__", "1.12.0"):
self.assertTrue(version.is_torch_version_geq_1_8())
self.assertTrue(version.is_torch_version_geq_1_9())
self.assertTrue(version.is_torch_version_geq_1_10())
self.assertTrue(version.is_torch_version_geq_1_11())
self.assertTrue(version.is_torch_version_geq_1_12())

with patch.object(torch, "__version__", "2.0.0a0"):
self.assertTrue(version.is_torch_version_ge_1_13_1())
self.assertFalse(version.is_torch_version_geq_2_0())
self.assertFalse(version.is_torch_version_geq("2.1.0"))
11 changes: 0 additions & 11 deletions torchtnt/framework/auto_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
TorchCompileParams,
)
from torchtnt.utils.swa import AveragedModel
from torchtnt.utils.version import is_torch_version_ge_1_13_1
from typing_extensions import Literal


Expand Down Expand Up @@ -166,8 +165,6 @@ def __init__(
torch_compile_params: Optional[TorchCompileParams] = None,
) -> None:
super().__init__()
if torch_compile_params:
_validate_torch_compile_available()

self.device: torch.device = device or init_from_env()
self.precision: Optional[torch.dtype] = (
Expand Down Expand Up @@ -879,11 +876,3 @@ def _update_lr_and_swa(self, state: State, number_of_steps_or_epochs: int) -> No
state, f"{self.__class__.__name__}.lr_scheduler_step"
):
self.step_lr_scheduler()


def _validate_torch_compile_available() -> None:
if not is_torch_version_ge_1_13_1():
raise RuntimeError(
"Torch compile support is available only in PyTorch 2.0 or higher. "
"Please install PyTorch 2.0 or higher to continue: https://pytorch.org/get-started/locally/"
)
7 changes: 3 additions & 4 deletions torchtnt/framework/callbacks/dcp_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
from torch.distributed import checkpoint as dcp

from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
from torch.distributed.checkpoint.state_dict import _init_optim_state
from torch.distributed.checkpoint.stateful import Stateful
from torchtnt.framework.callbacks._checkpoint_utils import (
_prepare_app_state_for_checkpoint,
_prepare_app_state_for_restore,
Expand All @@ -39,8 +37,9 @@
TTrainUnit,
)
from torchtnt.framework.utils import get_timing_context
from torchtnt.utils.optimizer import init_optim_state
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
from torchtnt.utils.stateful import MultiStateful
from torchtnt.utils.stateful import MultiStateful, Stateful


logger: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -249,7 +248,7 @@ def restore(
# `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`, this handles that case.
optimizer = getattr(obj, "optimizer", obj)
if isinstance(optimizer, torch.optim.Optimizer):
_init_optim_state(optimizer)
init_optim_state(optimizer)

dcp.load(
{"app_state": MultiStateful(app_state)},
Expand Down
Loading
Loading