Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 3, 2024
1 parent e5bc47c commit ba1344f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
10 changes: 4 additions & 6 deletions src/brevitas/core/function_wrapper/learned_round.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
return floor_ste(x) + p


# TODO: Restore JIT compatibility
class LearnedRoundIdentity(torch.nn.Module):
class LearnedRoundIdentity(brevitas.jit.ScriptModule):
"""
Implementation for LearnedRound learned parameter
Adapted from https://arxiv.org/abs/2309.05516
Expand All @@ -75,15 +74,14 @@ class LearnedRoundIdentity(torch.nn.Module):
def __init__(self) -> None:
super(LearnedRoundIdentity, self).__init__()
self.tensor_clamp = TensorClampSte()
self.upper_lower_bound = brevitas.jit.Attribute(0.5, float)

@brevitas.jit.ignore
def forward(self, p: torch.Tensor) -> torch.Tensor:
return self.tensor_clamp(
p,
min_val=torch.tensor(-0.5, device=p.device),
max_val=torch.tensor(+0.5, device=p.device))
min_val=torch.tensor(-self.upper_lower_bound).type_as(p),
max_val=torch.tensor(self.upper_lower_bound).type_as(p))

@brevitas.jit.ignore
def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
return round_ste(x + p)

Expand Down
4 changes: 3 additions & 1 deletion tests/brevitas/optim/test_sign_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,10 @@ class TestOptimSignSGD:

@device_dtype_parametrize
@pytest_cases.parametrize("lr", [0.1])
@requires_pt_ge('2.1') # TODO: revisit this
@requires_pt_ge('2.1') # TODO: revisit this
def test_sign_sgd_single_update(self, device, dtype, lr):
from brevitas.optim.sign_sgd import SignSGD

# Initialize weights and grads
weights = Parameter(REFERENCE_WEIGHTS.to(device=device, dtype=dtype))
# Initialize tensors to compute expected result
Expand All @@ -104,6 +105,7 @@ def test_sign_sgd_single_update(self, device, dtype, lr):
@requires_pt_ge('2.1')
def test_forloop_goes_right_direction(self, device, dtype, optimizer_kwargs, lr_scheduler_args):
from brevitas.optim.sign_sgd import SignSGD

# PyTorch version previous to 2.3.1. might no have mv (addmv_impl_cpu) implemented for Half
if dtype == torch.float16 and device == "cpu" and torch_version < parse('2.3.1'):
pytest.xfail(
Expand Down

0 comments on commit ba1344f

Please sign in to comment.