Skip to content

Commit

Permalink
Update learned_round.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Dec 2, 2024
1 parent 632e396 commit 1a2863b
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions src/brevitas/core/function_wrapper/learned_round.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
return floor_ste(x) + p


class LearnedRoundIdentity(brevitas.jit.ScriptModule):
# TODO: Restore JIT compatibility
class LearnedRoundIdentity(torch.nn.Module):
"""
Implementation for LearnedRound learned parameter
Adapted from https://arxiv.org/abs/2309.05516
Expand All @@ -75,14 +76,12 @@ def __init__(self) -> None:
super(LearnedRoundIdentity, self).__init__()
self.tensor_clamp = TensorClampSte()

@brevitas.jit.script_method
def forward(self, p: torch.Tensor) -> torch.Tensor:
return self.tensor_clamp(
p,
min_val=torch.tensor(-0.5, device=p.device),
max_val=torch.tensor(+0.5, device=p.device))

@brevitas.jit.script_method
def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
return round_ste(x + p)

Expand Down

0 comments on commit 1a2863b

Please sign in to comment.