Skip to content

Commit

Permalink
Add NaN placeholder in leaf forward pass
Browse files Browse the repository at this point in the history
  • Loading branch information
braun-steven committed Jan 10, 2024
1 parent 197266e commit c59971d
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions simple_einet/layers/distributions/abstract_leaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ def __init__(
# Marginalization constant
self.marginalization_constant = nn.Parameter(torch.zeros(1), requires_grad=False)

# Placeholder to replace nan values for the forward pass to circument errors in the torch distributions
# This value is distribution specific since it needs to be inside of the distribution support and might need to
# be adjusted
self.nan_placeholder = 0

def _apply_dropout(self, x: torch.Tensor) -> torch.Tensor:
"""
Applies dropout to the input tensor `x` according to the dropout probability
Expand Down Expand Up @@ -242,8 +247,18 @@ def forward(self, x, marginalized_scopes: List[int]):
"""
# Forward through base distribution
d = self._get_base_distribution()
nan_mask = torch.isnan(x)
if nan_mask.any():
# Replace nans with some valid value
x = torch.where(torch.isnan(x), self.nan_placeholder, x)

# Perform forward pass
x = dist_forward(d, x)

# Set back to nan
if nan_mask.any():
x[nan_mask] = torch.nan

x = self._marginalize_input(x, marginalized_scopes)

return x
Expand Down

0 comments on commit c59971d

Please sign in to comment.