diff --git a/torchmultimodal/models/masked_auto_encoder/swin_decoder.py b/torchmultimodal/models/masked_auto_encoder/swin_decoder.py index 12040691..5e39c5d5 100644 --- a/torchmultimodal/models/masked_auto_encoder/swin_decoder.py +++ b/torchmultimodal/models/masked_auto_encoder/swin_decoder.py @@ -81,8 +81,8 @@ def _make_pair_wise_relative_positions(self) -> None: relative_coordinates.permute(1, 2, 0).reshape(-1, 2).float() ) - relative_coordinates_log = torch.sign(relative_coordinates) * torch.log( - 1.0 + relative_coordinates.abs() + relative_coordinates_log = torch.sign(relative_coordinates) * torch.log1p( + relative_coordinates.abs() ) self.register_buffer( "relative_coordinates_log", relative_coordinates_log, persistent=False