Change logit initializer for categorical and binomial distribution
braun-steven committed Jun 17, 2024
1 parent 1518d99 commit 0166706
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 3 additions & 1 deletion simple_einet/layers/distributions/binomial.py
@@ -46,7 +46,9 @@ def __init__(
         self.total_count = check_valid(total_count, int, lower_bound=1)
 
         # Create binomial parameters as unnormalized log probabilities
-        self.logits = nn.Parameter(torch.randn(1, num_channels, num_features, num_leaves, num_repetitions))
+
+        p = 0.5 + (torch.rand(1, num_channels, num_features, num_leaves, num_repetitions) - 0.5) * 0.2
+        self.logits = nn.Parameter(probs_to_logits(p, is_binary=True))
 
     def _get_base_distribution(self, ctx: SamplingContext = None):
         # Cast logits to probabilities
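For context, a minimal standalone sketch (shapes chosen for illustration, not taken from the repository) of what the new binomial initializer does compared to the old one: probabilities drawn uniformly from [0.45, 0.55] map to logits near zero, so every leaf starts close to p = 0.5 instead of at a randomly skewed success probability.

import torch
from torch.distributions.utils import probs_to_logits

# Illustrative shape; the layer uses
# (1, num_channels, num_features, num_leaves, num_repetitions).
shape = (1, 1, 4, 8, 2)

# Old initializer: standard-normal logits, so initial success
# probabilities spread over roughly sigmoid(N(0, 1)).
old_logits = torch.randn(shape)

# New initializer: probabilities uniform in [0.45, 0.55], converted to
# logits via probs_to_logits(..., is_binary=True), i.e. log(p / (1 - p)).
p = 0.5 + (torch.rand(shape) - 0.5) * 0.2
new_logits = probs_to_logits(p, is_binary=True)

print(torch.sigmoid(old_logits).min().item(), torch.sigmoid(old_logits).max().item())  # wide spread
print(torch.sigmoid(new_logits).min().item(), torch.sigmoid(new_logits).max().item())  # ~0.45 .. ~0.55

Since the layer later casts the logits back to probabilities (see the context line above), only the starting point changes; the trainable parameter remains unconstrained.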
4 changes: 3 additions & 1 deletion simple_einet/layers/distributions/categorical.py
@@ -1,4 +1,5 @@
 import torch
+from torch.distributions.utils import probs_to_logits
 from torch import distributions as dist
 from torch import nn
 from torch.nn import functional as F
@@ -27,7 +28,8 @@ def __init__(self, num_features: int, num_channels: int, num_leaves: int, num_re
         super().__init__(num_features, num_channels, num_leaves, num_repetitions)
 
         # Create logits
-        self.logits = nn.Parameter(torch.randn(1, num_channels, num_features, num_leaves, num_repetitions, num_bins))
+        p = 0.5 + (torch.rand(1, num_channels, num_features, num_leaves, num_repetitions, num_bins) - 0.5) * 0.2
+        self.logits = nn.Parameter(probs_to_logits(p))
 
     def _get_base_distribution(self, ctx: SamplingContext = None):
         # Use sigmoid to ensure, that probs are in valid range
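Similarly, a small sketch (again with made-up shapes) for the categorical case: without is_binary=True, probs_to_logits just takes the log of the clamped probabilities, leaving them unnormalized. Per-bin values in [0.45, 0.55] therefore yield per-bin logits that are close together, so the initial distribution over bins is near-uniform once normalized. The normalization below uses torch.distributions.Categorical for illustration; the layer's own _get_base_distribution may normalize differently (its comment mentions sigmoid), but the near-uniform starting point is the same.

import torch
from torch.distributions.utils import probs_to_logits

# Illustrative shape; the layer uses
# (1, num_channels, num_features, num_leaves, num_repetitions, num_bins).
num_bins = 10
shape = (1, 1, 4, 8, 2, num_bins)

p = 0.5 + (torch.rand(shape) - 0.5) * 0.2   # per-bin values in [0.45, 0.55]
logits = probs_to_logits(p)                 # log of clamped p, unnormalized

# Categorical treats the last dim as categories and softmax-normalizes,
# so every bin starts near 1 / num_bins.
probs = torch.distributions.Categorical(logits=logits).probs
print(probs.min().item(), probs.max().item(), 1.0 / num_bins)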
