From 1a2863be9eb10860172bafeb2d37cca08ea82052 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 3 Dec 2024 00:33:23 +0100 Subject: [PATCH] Update learned_round.py --- src/brevitas/core/function_wrapper/learned_round.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/brevitas/core/function_wrapper/learned_round.py b/src/brevitas/core/function_wrapper/learned_round.py index cfc78d03f..07f5f5d0f 100644 --- a/src/brevitas/core/function_wrapper/learned_round.py +++ b/src/brevitas/core/function_wrapper/learned_round.py @@ -65,7 +65,8 @@ def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor: return floor_ste(x) + p -class LearnedRoundIdentity(brevitas.jit.ScriptModule): +# TODO: Restore JIT compatibility +class LearnedRoundIdentity(torch.nn.Module): """ Implementation for LearnedRound learned parameter Adapted from https://arxiv.org/abs/2309.05516 @@ -75,14 +76,12 @@ def __init__(self) -> None: super(LearnedRoundIdentity, self).__init__() self.tensor_clamp = TensorClampSte() - @brevitas.jit.script_method def forward(self, p: torch.Tensor) -> torch.Tensor: return self.tensor_clamp( p, min_val=torch.tensor(-0.5, device=p.device), max_val=torch.tensor(+0.5, device=p.device)) - @brevitas.jit.script_method def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor: return round_ste(x + p)