refactor(mixture): rename EinetMixture to Mixture
braun-steven committed Nov 6, 2024
1 parent 9f8b0b7 · commit 163f16d
Showing 2 changed files with 41 additions and 24 deletions.
simple_einet/mixture.py: 38 additions & 21 deletions
@@ -1,4 +1,6 @@
 from _operator import xor
+from simple_einet.data import Shape
+from simple_einet.conv_pc import ConvPc, ConvPcConfig
 from collections import defaultdict
 from typing import List, Optional, Sequence
 
@@ -8,24 +10,35 @@
 from torch.utils.data import DataLoader
 
 from simple_einet.einet import EinetConfig, Einet
+from simple_einet.layers.distributions.binomial import Binomial
 from simple_einet.type_checks import check_valid
 
 
-class EinetMixture(nn.Module):
-    def __init__(self, n_components: int, einet_config: EinetConfig):
+class Mixture(nn.Module):
+    def __init__(self, n_components: int, config: EinetConfig | ConvPcConfig, data_shape: Shape | None = None):
         super().__init__()
         self.n_components = check_valid(n_components, expected_type=int, lower_bound=1)
-        self.config = einet_config
-
-        einets = []
-
-        for i in range(n_components):
-            einets.append(Einet(einet_config))
+        self.config = config
+
+        models = []
+
+        # Construct single models
+        for _ in range(n_components):
+            if isinstance(config, EinetConfig):
+                model = Einet(config)
+                self.num_features = config.num_features
+            elif isinstance(config, ConvPcConfig):
+                assert data_shape is not None, "data_shape must not be None (required for ConvPc)"
+                self.num_features = data_shape.num_pixels * data_shape.channels
+                model = ConvPc(config, data_shape)
+            else:
+                raise ValueError(f"Unknown model config type: {type(config)}")
+            models.append(model)
 
-        self.einets: Sequence[Einet] = nn.ModuleList(einets)
+        self.models: Sequence[Einet | ConvPc] = nn.ModuleList(models)
         self._kmeans = KMeans(n_clusters=self.n_components, mode="euclidean", verbose=1)
-        self.mixture_weights = nn.Parameter(torch.empty(n_components), requires_grad=False)
-        self.centroids = nn.Parameter(torch.empty(n_components, einet_config.num_features), requires_grad=False)
+        self.register_buffer("mixture_weights", torch.empty(n_components))
+        self.register_buffer("centroids", torch.empty(n_components, self.num_features))
 
     @torch.no_grad()
     def initialize(self, data: torch.Tensor = None, dataloader: DataLoader = None, device=None):
@@ -37,7 +50,7 @@ def initialize(self, data: torch.Tensor = None, dataloader: DataLoader = None, device=None):
             for batch in dataloader:
                 x, y = batch
                 l.append(x)
-                if sum([d.shape[0] for d in l]) > 10000:
+                if sum([d.shape[0] for d in l]) > 30000:
                     break
 
             data = torch.cat(l, dim=0).to(device)
@@ -49,10 +62,14 @@ def initialize(self, data: torch.Tensor = None, dataloader: DataLoader = None, device=None):
 
         self.centroids.data = self._kmeans.centroids
 
+        # Scale centroids if necessary
+        if self.config.leaf_type == Binomial:
+            self.centroids.data = self.centroids.data * self.config.leaf_kwargs["total_count"]
+
     def _predict_cluster(self, x, marginalized_scopes: Optional[List[int]] = None):
        x = x.view(x.shape[0], -1)  # input needs to be [n, d]
         if marginalized_scopes is not None:
-            keep_idx = list(sorted([i for i in range(self.config.num_features) if i not in marginalized_scopes]))
+            keep_idx = list(sorted([i for i in range(self.num_features) if i not in marginalized_scopes]))
             centroids = self.centroids[:, keep_idx]
             x = x[:, keep_idx]
         else:
@@ -71,15 +88,15 @@ def _separate_data_by_cluster(self, x: torch.Tensor, marginalized_scope: List[int]):
         return separated_idxs, separated_data
 
     def forward(self, x, marginalized_scope: torch.Tensor = None):
-        assert self._kmeans is not None, "EinetMixture has not been initialized yet."
+        assert self._kmeans is not None, "Mixture has not been initialized yet."
 
         separated_idxs, separated_data = self._separate_data_by_cluster(x, marginalized_scope)
 
         lls_result = []
         data_idxs_all = []
         for cluster_idx, data_list in separated_data.items():
             data_tensor = torch.stack(data_list, dim=0)
-            lls = self.einets[cluster_idx](data_tensor)
+            lls = self.models[cluster_idx](data_tensor)
 
             data_idxs = separated_idxs[cluster_idx]
             for data_idx, ll in zip(data_idxs, lls):
@@ -98,14 +115,14 @@ def forward(self, x, marginalized_scope: torch.Tensor = None):
 
     def sample(
         self,
-        num_samples: int = None,
-        num_samples_per_cluster: int = None,
+        num_samples: Optional[int] = None,
+        num_samples_per_cluster: Optional[int] = None,
         class_index=None,
-        evidence: torch.Tensor = None,
+        evidence: Optional[torch.Tensor] = None,
         is_mpe: bool = False,
         temperature_leaves: float = 1.0,
         temperature_sums: float = 1.0,
-        marginalized_scopes: List[int] = None,
+        marginalized_scopes: Optional[List[int]] = None,
         seed=None,
         mpe_at_leaves: bool = False,
     ):
@@ -140,7 +157,7 @@ def sample(
 
         samples_all = []
         for cluster_idx, num_samples_cluster in separated_idxs.items():
-            samples = self.einets[cluster_idx].sample(
+            samples = self.models[cluster_idx].sample(
                 num_samples_cluster,
                 class_index=class_index,
                 evidence=evidence,
@@ -172,7 +189,7 @@ def sample(
         evidence_idxs_all = []
        for cluster_idx, evidence_pre_cluster in separated_data.items():
             evidence_per_cluster = torch.stack(evidence_pre_cluster, dim=0)
-            samples = self.einets[cluster_idx].sample(
+            samples = self.models[cluster_idx].sample(
                 evidence=evidence_per_cluster,
                 is_mpe=is_mpe,
                 temperature_leaves=temperature_leaves,
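
For orientation, here is a minimal usage sketch of the renamed class. It is not part of this commit: only num_features, leaf_type, and leaf_kwargs are confirmed by this diff as EinetConfig fields, and the remaining field names and all concrete values below are assumptions based on the wider simple_einet repository.

# Hypothetical usage sketch of the renamed Mixture class (not part of this
# commit). EinetConfig fields beyond num_features/leaf_type/leaf_kwargs are
# assumed from the surrounding repository.
import torch

from simple_einet.einet import EinetConfig
from simple_einet.layers.distributions.normal import Normal
from simple_einet.mixture import Mixture

config = EinetConfig(
    num_features=784,  # e.g. flattened 28x28 images
    depth=3,
    num_sums=10,
    num_leaves=10,
    num_repetitions=5,
    leaf_type=Normal,
    leaf_kwargs={},
)

# One Einet per mixture component; a ConvPcConfig would additionally require
# a Shape via the new data_shape argument.
mixture = Mixture(n_components=4, config=config)

# K-means assigns each training sample to a component; after this commit the
# centroids and mixture weights are stored as (non-trainable) buffers.
data = torch.randn(1024, 784)
mixture.initialize(data=data)

log_likelihoods = mixture(data)           # routes each sample to its cluster's model
samples = mixture.sample(num_samples=16)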
simple_einet/sampling_utils.py: 3 additions & 3 deletions
@@ -324,14 +324,14 @@ def init_einet_stats(einet: "Einet", dataloader: torch.utils.data.DataLoader):
 
     from simple_einet.layers.distributions.normal import Normal
     from simple_einet.einet import Einet
-    from simple_einet.einet_mixture import EinetMixture
+    from simple_einet.mixture import Mixture
 
     # Set leaf parameters for normal distribution
     if einet.config.leaf_type == Normal:
         if type(einet) == Einet:
             einets = [einet]
-        elif type(einet) == EinetMixture:
-            einets = einet.einets
+        elif type(einet) == Mixture:
+            einets = einet.models
         else:
             raise ValueError(f"Invalid einet type: {type(einet)} -- must be Einet or EinetMixture.")
 
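
For downstream code, the rename is a mechanical migration. Below is a hedged before/after sketch based only on the removed and added lines in this commit; the config object is the one constructed in the sketch above.

# Migration sketch (config as in the sketch above). Old spelling, per the
# removed lines in this commit:
#   from simple_einet.einet_mixture import EinetMixture
#   model = EinetMixture(n_components=3, einet_config=config)
#   components = model.einets
#
# New spelling after this commit:
from simple_einet.mixture import Mixture

model = Mixture(n_components=3, config=config)  # einet_config -> config
components = model.models                       # .einets -> .models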
