diff --git a/src/brevitas/core/function_wrapper/learned_round.py b/src/brevitas/core/function_wrapper/learned_round.py index cfc78d03f..07f5f5d0f 100644 --- a/src/brevitas/core/function_wrapper/learned_round.py +++ b/src/brevitas/core/function_wrapper/learned_round.py @@ -65,7 +65,8 @@ def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor: return floor_ste(x) + p -class LearnedRoundIdentity(brevitas.jit.ScriptModule): +# TODO: Restore JIT compatibility +class LearnedRoundIdentity(torch.nn.Module): """ Implementation for LearnedRound learned parameter Adapted from https://arxiv.org/abs/2309.05516 @@ -75,14 +76,12 @@ def __init__(self) -> None: super(LearnedRoundIdentity, self).__init__() self.tensor_clamp = TensorClampSte() - @brevitas.jit.script_method def forward(self, p: torch.Tensor) -> torch.Tensor: return self.tensor_clamp( p, min_val=torch.tensor(-0.5, device=p.device), max_val=torch.tensor(+0.5, device=p.device)) - @brevitas.jit.script_method def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor: return round_ste(x + p)