Skip to content

Commit

Permalink
Fix scope construction with equalized groups
Browse files Browse the repository at this point in the history
Before this it could happen that the groups where of size [3, 3, 3, ..., 82],
where the last group was basically "the rest" (modulo). Now "the rest" is
distributed among all other groups such that it becomes [3, 3, 4, 3, 4, 3, ...,
3].
  • Loading branch information
braun-steven committed Mar 12, 2024
1 parent b4b7fda commit 77dfb46
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions simple_einet/layers/factorized_leaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,25 @@ def __init__(
self.num_features_out = num_features_out

# Size of the factorized groups of RVs
cardinality = int(np.round(self.num_features / self.num_features_out))
cardinality = int(np.floor(self.num_features / self.num_features_out))

# Construct equal group sizes, such that (sum(group_sizes) == num_features) and the are num_features_out groups
group_sizes = np.ones(self.num_features_out, dtype=int) * cardinality
rest = self.num_features - cardinality * self.num_features_out
for i in range(rest):
group_sizes[i] += 1
np.random.shuffle(group_sizes)

# Construct mapping of scopes from in_features -> out_features
scopes = torch.zeros(num_features, self.num_features_out, num_repetitions)
for r in range(num_repetitions):
idxs = torch.randperm(n=self.num_features)
offset = 0
for o in range(num_features_out):
low = o * cardinality
high = (o + 1) * cardinality
group_size = group_sizes[o]
low = offset
high = offset + group_size
offset = high
if o == num_features_out - 1:
high = self.num_features
scopes[idxs[low:high], o, r] = 1
Expand Down

0 comments on commit 77dfb46

Please sign in to comment.