From 632e39632efd85b0c8d556bb162286dcaa0e74d9 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 3 Dec 2024 00:19:52 +0100 Subject: [PATCH] Update test_sign_sgd.py --- tests/brevitas/optim/test_sign_sgd.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/brevitas/optim/test_sign_sgd.py b/tests/brevitas/optim/test_sign_sgd.py index 10b5dfe5b..5970ac262 100644 --- a/tests/brevitas/optim/test_sign_sgd.py +++ b/tests/brevitas/optim/test_sign_sgd.py @@ -51,7 +51,6 @@ from torch.optim.lr_scheduler import LinearLR from brevitas import torch_version -from brevitas.optim.sign_sgd import SignSGD from tests.conftest import SEED from tests.marker import requires_pt_ge @@ -80,8 +79,9 @@ class TestOptimSignSGD: @device_dtype_parametrize @pytest_cases.parametrize("lr", [0.1]) - @requires_pt_ge('2.1') + @requires_pt_ge('2.1') # TODO: revisit this def test_sign_sgd_single_update(self, device, dtype, lr): + from brevitas.optim.sign_sgd import SignSGD # Initialize weights and grads weights = Parameter(REFERENCE_WEIGHTS.to(device=device, dtype=dtype)) # Initialize tensors to compute expected result @@ -103,6 +103,7 @@ def test_sign_sgd_single_update(self, device, dtype, lr): @pytest_cases.parametrize("lr_scheduler_args", LR_SCHEDULER_ARGS) @requires_pt_ge('2.1') def test_forloop_goes_right_direction(self, device, dtype, optimizer_kwargs, lr_scheduler_args): + from brevitas.optim.sign_sgd import SignSGD # PyTorch version previous to 2.3.1. might no have mv (addmv_impl_cpu) implemented for Half if dtype == torch.float16 and device == "cpu" and torch_version < parse('2.3.1'): pytest.xfail( @@ -145,6 +146,7 @@ def closure(): @requires_pt_ge('2.1') def test_forloop_goes_right_direction_multigpu( self, dtype, optimizer_kwargs, lr_scheduler_args): + from brevitas.optim.sign_sgd import SignSGD optim_cls = SignSGD # Learnable parameters weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype))