Skip to content

Commit

Permalink
Update learned_round.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Dec 2, 2024
1 parent 1a2863b commit e5bc47c
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/brevitas/core/function_wrapper/learned_round.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,14 @@ def __init__(self) -> None:
super(LearnedRoundIdentity, self).__init__()
self.tensor_clamp = TensorClampSte()

@brevitas.jit.ignore
def forward(self, p: torch.Tensor) -> torch.Tensor:
return self.tensor_clamp(
p,
min_val=torch.tensor(-0.5, device=p.device),
max_val=torch.tensor(+0.5, device=p.device))

@brevitas.jit.ignore
def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
return round_ste(x + p)

Expand Down

0 comments on commit e5bc47c

Please sign in to comment.