Skip to content

Commit

Permalink
Update test_sign_sgd.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Dec 2, 2024
1 parent eb498bc commit 6927700
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/brevitas/optim/test_sign_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@

class TestOptimSignSGD:

@requires_pt_ge('2.1')
@device_dtype_parametrize
@pytest_cases.parametrize("lr", [0.1])
@requires_pt_ge('2.1')
def test_sign_sgd_single_update(self, device, dtype, lr):
# Initialize weights and grads
weights = Parameter(REFERENCE_WEIGHTS.to(device=device, dtype=dtype))
Expand All @@ -98,10 +98,10 @@ def test_sign_sgd_single_update(self, device, dtype, lr):

assert torch.allclose(weights, initial_weights - lr * weight_sign_grad)

@requires_pt_ge('2.1')
@device_dtype_parametrize
@pytest_cases.parametrize("optimizer_kwargs", OPTIMIZER_KWARGS)
@pytest_cases.parametrize("lr_scheduler_args", LR_SCHEDULER_ARGS)
@requires_pt_ge('2.1')
def test_forloop_goes_right_direction(self, device, dtype, optimizer_kwargs, lr_scheduler_args):
# PyTorch version previous to 2.3.1. might no have mv (addmv_impl_cpu) implemented for Half
if dtype == torch.float16 and device == "cpu" and torch_version < parse('2.3.1'):
Expand Down Expand Up @@ -137,12 +137,12 @@ def closure():
else:
assert closure().item() < initial_value

@requires_pt_ge('2.1')
@pytest.mark.skipif(
torch.cuda.device_count() <= 1, reason="At least two GPUs are required for this test.")
@pytest_cases.parametrize("optimizer_kwargs", OPTIMIZER_KWARGS)
@pytest_cases.parametrize("lr_scheduler_args", LR_SCHEDULER_ARGS)
@pytest_cases.parametrize("dtype", [torch.float16, torch.float32])
@requires_pt_ge('2.1')
def test_forloop_goes_right_direction_multigpu(
self, dtype, optimizer_kwargs, lr_scheduler_args):
optim_cls = SignSGD
Expand Down

0 comments on commit 6927700

Please sign in to comment.