Skip to content

Commit

Permalink
Update test_sign_sgd.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Dec 2, 2024
1 parent 6927700 commit 632e396
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tests/brevitas/optim/test_sign_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
from torch.optim.lr_scheduler import LinearLR

from brevitas import torch_version
from brevitas.optim.sign_sgd import SignSGD
from tests.conftest import SEED
from tests.marker import requires_pt_ge

Expand Down Expand Up @@ -80,8 +79,9 @@ class TestOptimSignSGD:

@device_dtype_parametrize
@pytest_cases.parametrize("lr", [0.1])
@requires_pt_ge('2.1')
@requires_pt_ge('2.1') # TODO: revisit this
def test_sign_sgd_single_update(self, device, dtype, lr):
from brevitas.optim.sign_sgd import SignSGD
# Initialize weights and grads
weights = Parameter(REFERENCE_WEIGHTS.to(device=device, dtype=dtype))
# Initialize tensors to compute expected result
Expand All @@ -103,6 +103,7 @@ def test_sign_sgd_single_update(self, device, dtype, lr):
@pytest_cases.parametrize("lr_scheduler_args", LR_SCHEDULER_ARGS)
@requires_pt_ge('2.1')
def test_forloop_goes_right_direction(self, device, dtype, optimizer_kwargs, lr_scheduler_args):
from brevitas.optim.sign_sgd import SignSGD
# PyTorch version previous to 2.3.1. might no have mv (addmv_impl_cpu) implemented for Half
if dtype == torch.float16 and device == "cpu" and torch_version < parse('2.3.1'):
pytest.xfail(
Expand Down Expand Up @@ -145,6 +146,7 @@ def closure():
@requires_pt_ge('2.1')
def test_forloop_goes_right_direction_multigpu(
self, dtype, optimizer_kwargs, lr_scheduler_args):
from brevitas.optim.sign_sgd import SignSGD
optim_cls = SignSGD
# Learnable parameters
weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype))
Expand Down

0 comments on commit 632e396

Please sign in to comment.