feat: add support for top-down structure in Einet
Refactor and configuration update for modular model architectures.

- Added "--structure" argument to args for model structure specification ("top-down" or "bottom-up").
- Refined configuration in `config.yaml` for greater control over einet and convpc layers, with updates to channels, structure, and kernel settings.
- Refactored `main.py` and `main_pl.py` to streamline model structure handling and simplify output generation.
- Updated `models_pl.py` to support modular architecture selection between `Einet` and `ConvPc`.
- Corrected image shape for CelebA dataset in `data.py` to (3, 128, 128).
- Adjusted `einet.py` to support new top-down and bottom-up structure choices.

Test: Minor test adjustments for structure argument compatibility.
braun-steven committed Nov 6, 2024
1 parent 163f16d commit a3f93f5
Showing 8 changed files with 96 additions and 107 deletions.
1 change: 1 addition & 0 deletions args.py
@@ -161,6 +161,7 @@ def parse_args():

parser.add_argument("--log-weights", action="store_true", help="use log weights")
parser.add_argument("--num-devices", type=int, default=1, help="number of devices")
parser.add_argument("--structure", default="top-down", choices=["bottom-up", "top-down"], help="structure of the network")

# Parse args
args = parser.parse_args()
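For reference, a minimal self-contained sketch of the new flag's behaviour, mirroring only the argparse line added above (`demo_parser` is an illustrative name, not repository code):

```python
# Hedged sketch: the new --structure flag in isolation.
import argparse

demo_parser = argparse.ArgumentParser()
demo_parser.add_argument(
    "--structure",
    default="top-down",
    choices=["bottom-up", "top-down"],
    help="structure of the network",
)

print(demo_parser.parse_args([]).structure)                            # -> top-down (default)
print(demo_parser.parse_args(["--structure", "bottom-up"]).structure)  # -> bottom-up
```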
25 changes: 18 additions & 7 deletions conf/config.yaml
@@ -16,6 +16,21 @@ hydra:
file:
filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log # Fixed in hydra-colorlog version 1.2.1

einet:
S: 10
I: 10
D: 3
R: 1
layer_type: "linsum"
structure: "top-down"

convpc:
channels: [8, 16, 16, 16]
order: "sum-prod"
structure: "top-down"
kernel_size: 2



# Default set of configurations.
data_dir: "${oc.env:DATA_DIR}/"
@@ -36,15 +51,9 @@ log_interval: 10
classification: False
device: "cuda"
debug: False
S: 10
I: 10
D: 3
R: 1
gpu: 0
epochs: 10
load_and_eval: False
layer_type: "linsum"
dist: "normal"
precision: "bf16-mixed"
group_tag: ???
tag: ???
@@ -54,6 +63,8 @@ profiler: ???
dataset: ???
num_classes: 10
init_leaf_data: False
einet_mixture: False
mixture: False
torch_compile: False
multivariate_cardinality: 2
dist: "normal"
model: "einet" # Can be one of "einet" or "convpc"
75 changes: 4 additions & 71 deletions main.py
@@ -23,7 +23,7 @@
from simple_einet.layers.distributions.binomial import Binomial
from simple_einet.layers.distributions.normal import RatNormal, Normal
from simple_einet.einet import Einet, EinetConfig
from simple_einet.einet_mixture import EinetMixture
from simple_einet.mixture import Mixture

import lightning as L

@@ -42,7 +42,7 @@ def log_likelihoods(outputs, targets=None):
return lls


def train(args, model: Union[Einet, EinetMixture], device, train_loader, optimizer, epoch):
def train(args, model: Union[Einet, Mixture], device, train_loader, optimizer, epoch):
model.train()

pbar = tqdm.tqdm(train_loader)
@@ -62,15 +62,8 @@ def train(args, model: Union[Einet, EinetMixture], device, train_loader, optimiz

optimizer.zero_grad()

if args.dist == Dist.PIECEWISE_LINEAR:
cache_leaf = True
cache_index = batch_idx
else:
cache_leaf = False
cache_index = None

# Generate outputs
outputs = model(data, cache_leaf=cache_leaf, cache_index=cache_index)
outputs = model(data)

if args.classification:
model.posterior(data)
@@ -191,6 +184,7 @@ def test(model, device, loader, tag):
leaf_kwargs=leaf_kwargs,
layer_type=args.layer,
dropout=0.0,
structure=args.structure,
)

fabric = L.Fabric(accelerator=args.device, devices=args.num_devices, precision="16-mixed")
@@ -227,67 +221,6 @@ def test(model, device, loader, tag):

train_loader, val_loader, test_loader = fabric.setup_dataloaders(train_loader, val_loader, test_loader)

if args.dist == Dist.PIECEWISE_LINEAR:
# Initialize the piecewise linear function
# Collect data
batches = []
count = 0
for data, _ in train_loader:
batches.append(data)
count += data.shape[0]
if count > 10000:
break
data_init_pwl = torch.cat(batches, dim=0)

# Prepare data
data_init_pwl = preprocess(
data_init_pwl,
n_bits,
n_bins,
dequantize=True,
has_gauss_dist=has_gauss_dist,
)

data_init_pwl = data_init_pwl.view(data_init_pwl.shape[0], data_init_pwl.shape[1], num_features)

domains = [Domain.discrete_range(min=0, max=255)] * num_features
with torch.no_grad():
model.leaf.base_leaf.initialize(data_init_pwl, domains=domains)

# Use mixture weights obtained in leaf initialization and set these to the first linsum layer weights
model.layers[0].logits.data[:] = model.leaf.base_leaf.mixture_weights.permute(1, 0).view(1, config.num_leaves, 1, config.num_repetitions).log()

# Visualize a couple of pixel distributions and their piecewise linear functions
# Select 20 random pixels
pixels = list(range(64))[::3]
# pixels = [36, 766, 720, 588, 759, 403, 664, 428, 25, 686, 673, 638, 44, 147, 610, 470, 540, 179, 698, 420]

d = model.leaf.base_leaf._get_base_distribution()
log_probs = d.log_prob(data_init_pwl)

xs = d.xs
ys = d.ys

for pixel in pixels:
# Get data subset
# xs_pixel = xs[pixel][0][0][0].squeeze()
# ys_pixel = ys[pixel][0][0][0].squeeze()
xs_pixel = xs[0][0][pixel][0].squeeze().cpu()
ys_pixel = ys[0][0][pixel][0].squeeze().cpu()

# Plot pixel distribution with pixel value as x and logprob as y values
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 6))
plt.plot(xs_pixel, ys_pixel, label="PWL")

# Plot histogram of pixel values
plt.hist(data_init_pwl[:, :, pixel].flatten().cpu().numpy(), bins=100, density=True, alpha=0.5, label="Data")
plt.xlabel("Pixel Value")
plt.ylabel("Density")
plt.legend()
plt.savefig(os.path.join(result_dir, f"pwl-{pixel}.png"), dpi=300)
plt.close()

if args.train:
for epoch in range(1, args.epochs + 1):
5 changes: 2 additions & 3 deletions main_pl.py
@@ -58,8 +58,6 @@ def main(cfg: DictConfig):
logger.info("\n" + OmegaConf.to_yaml(cfg, resolve=True))
logger.info("Run dir: " + run_dir)

seed_everything(cfg.seed, workers=True)

if not cfg.wandb:
os.environ["WANDB_MODE"] = "offline"

@@ -87,6 +85,7 @@
num_workers=min(cfg.num_workers, os.cpu_count()),
loop=False,
normalize=normalize,
seed=cfg.seed,
)

# Create callbacks
@@ -120,7 +119,7 @@
# model = torch.compile(model)
raise NotImplementedError("Torch compilation not yet supported with einsum.")

if cfg.einet_mixture:
if cfg.mixture:
# If we chose a mixture of einets, we need to initialize the mixture weights
logger.info("Initializing Einet mixture weights")
model.spn.initialize(dataloader=train_loader, device=devices[0])
67 changes: 56 additions & 11 deletions models_pl.py
@@ -10,16 +10,17 @@
from rtpt import RTPT
from torch import nn

from simple_einet.conv_pc import ConvPcConfig, ConvPc
from simple_einet.data import get_data_shape
from simple_einet.dist import Dist, get_distribution
from simple_einet.einet import EinetConfig, Einet
from simple_einet.einet_mixture import EinetMixture
from simple_einet.mixture import Mixture

# Translate the dataloader index to the dataset name
DATALOADER_ID_TO_SET_NAME = {0: "train", 1: "val", 2: "test"}


def make_einet(cfg, num_classes: int = 1) -> EinetMixture | Einet:
def make_einet(cfg, num_classes: int = 1) -> Mixture | Einet:
"""
Make an Einet model based off the given arguments.
@@ -38,22 +39,55 @@ def make_einet(cfg, num_classes: int = 1) -> EinetMixture | Einet:
config = EinetConfig(
num_features=image_shape.num_pixels,
num_channels=image_shape.channels,
depth=cfg.D,
num_sums=cfg.S,
num_leaves=cfg.I,
num_repetitions=cfg.R,
depth=cfg.einet.D,
num_sums=cfg.einet.S,
num_leaves=cfg.einet.I,
num_repetitions=cfg.einet.R,
num_classes=num_classes,
leaf_kwargs=leaf_kwargs,
leaf_type=leaf_type,
dropout=cfg.dropout,
layer_type=cfg.layer_type,
layer_type=cfg.einet.layer_type,
structure=cfg.einet.structure,
)
if cfg.einet_mixture:
return EinetMixture(n_components=num_classes, einet_config=config)
if cfg.mixture:
return Mixture(n_components=num_classes, config=config)
else:
return Einet(config)


def make_convpc(cfg, num_classes: int = 1) -> Mixture | ConvPc:
"""
Make ConvPc model based off the given arguments.
Args:
cfg: Arguments parsed from argparse.
num_classes: Number of classes to model.
Returns:
ConvPc model.
"""

image_shape = get_data_shape(cfg.dataset)
# leaf_kwargs, leaf_type = {"total_count": 255}, Binomial
leaf_kwargs, leaf_type = get_distribution(dist=cfg.dist, cfg=cfg)

config = ConvPcConfig(
channels=cfg.convpc.channels,
num_channels=image_shape.channels,
num_classes=num_classes,
leaf_kwargs=leaf_kwargs,
leaf_type=leaf_type,
structure=cfg.convpc.structure,
order=cfg.convpc.order,
kernel_size=cfg.convpc.kernel_size,
)
if cfg.mixture:
return Mixture(n_components=num_classes, config=config, data_shape=image_shape)
else:
return ConvPc(config=config, data_shape=image_shape)


class LitModel(pl.LightningModule, ABC):
"""
LightningModule for training a model using PyTorch Lightning.
@@ -123,7 +157,13 @@ class SpnGenerative(LitModel):

def __init__(self, cfg: DictConfig, steps_per_epoch: int):
super().__init__(cfg=cfg, name="gen", steps_per_epoch=steps_per_epoch)
self.spn = make_einet(cfg)
if cfg.model == "einet":
self.spn = make_einet(cfg, num_classes=cfg.num_classes)
elif cfg.model == "convpc":
self.spn = make_convpc(cfg, num_classes=cfg.num_classes)
else:
raise ValueError(f"Unknown model {cfg.model}")


def training_step(self, train_batch, batch_idx):
data, labels = train_batch
@@ -209,7 +249,12 @@ def __init__(self, cfg: DictConfig, steps_per_epoch: int):
super().__init__(cfg, name="disc", steps_per_epoch=steps_per_epoch)

# Construct SPN
self.spn = make_einet(cfg, num_classes=10)
if cfg.model == "einet":
self.spn = make_einet(cfg, num_classes=cfg.num_classes)
elif cfg.model == "convpc":
self.spn = make_convpc(cfg, num_classes=cfg.num_classes)
else:
raise ValueError(f"Unknown model {cfg.model}")

# Define loss function
self.criterion = nn.NLLLoss()
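Both `SpnGenerative` and `SpnDiscriminative` now dispatch on `cfg.model`; a self-contained sketch of that pattern with stand-in builders (not the repository's `make_einet`/`make_convpc`):

```python
# Illustrative stand-ins only; the real builders return Einet/ConvPc/Mixture instances.
from dataclasses import dataclass


@dataclass
class DemoCfg:
    model: str
    num_classes: int = 10


def demo_make_einet(cfg: DemoCfg) -> str:
    return f"Einet(num_classes={cfg.num_classes})"


def demo_make_convpc(cfg: DemoCfg) -> str:
    return f"ConvPc(num_classes={cfg.num_classes})"


def build_spn(cfg: DemoCfg) -> str:
    if cfg.model == "einet":
        return demo_make_einet(cfg)
    elif cfg.model == "convpc":
        return demo_make_convpc(cfg)
    raise ValueError(f"Unknown model {cfg.model}")


print(build_spn(DemoCfg(model="convpc")))  # -> ConvPc(num_classes=10)
```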
2 changes: 1 addition & 1 deletion simple_einet/data.py
@@ -96,7 +96,7 @@ def get_data_shape(dataset_name: str) -> Shape:
"cifar": (3, 32, 32),
"svhn": (3, 32, 32),
"svhn-extra": (3, 32, 32),
"celeba": (3, 64, 64),
"celeba": (3, 128, 128),
"celeba-small": (3, 64, 64),
"celeba-tiny": (3, 32, 32),
"lsun": (3, 32, 32),
24 changes: 12 additions & 12 deletions simple_einet/einet.py
@@ -37,7 +37,7 @@ class EinetConfig:
leaf_type: Type = None # Type of the leaf base class (Normal, Bernoulli, etc)
leaf_kwargs: Dict[str, Any] = field(default_factory=dict) # Parameters for the leaf base class
layer_type: str = "linsum" # Indicates the intermediate layer type: linsum or einsum
structure: str = "original" # Structure of the Einet: original or bottom_up
structure: str = "top-down" # Structure of the Einet: top-down or bottom-up

def assert_valid(self):
"""Check whether the configuration is valid."""
@@ -58,9 +58,9 @@ def assert_valid(self):
"einsum",
], f"Invalid layer type {self.layer_type}. Must be 'linsum' or 'einsum'."
assert self.structure in [
"original",
"bottom_up",
], f"Invalid structure type {self.structure}. Must be 'original' or 'bottom_up'."
"top-down",
"bottom-up",
], f"Invalid structure type {self.structure}. Must be 'top-down' or 'bottom-up'."

assert isinstance(self.leaf_type, type) and issubclass(
self.leaf_type, AbstractLeaf
@@ -72,7 +72,7 @@ def assert_valid(self):
else:
cardinality = 1

if self.structure == "bottom_up":
if self.structure == "bottom-up":
assert self.layer_type == "linsum", "Bottom-up structure only supports LinsumLayer due to handling of padding (not implemented for einsumlayer yet)."

# Get minimum number of features present at the lowest layer (num_features is the actual input dimension,
@@ -104,12 +104,12 @@ def __init__(self, config: EinetConfig):
self.config = config

# Construct the architecture
if self.config.structure == "original":
self._build_structure_original()
elif self.config.structure == "bottom_up":
if self.config.structure == "top-down":
self._build_structure_top_down()
elif self.config.structure == "bottom-up":
self._build_structure_bottom_up()
else:
raise ValueError(f"Invalid structure type {self.config.structure}. Must be 'original' or 'bottom_up'.")
raise ValueError(f"Invalid structure type {self.config.structure}. Must be '_riginal' or 'bottom-up'.")

# Leaf cache
self._leaf_cache = {}
@@ -235,9 +235,9 @@ def posterior(self, x) -> torch.Tensor:

return posterior(ll_x_g_y, self.config.num_classes)

def _build_structure_original(self):
def _build_structure_top_down(self):
"""Construct the internal architecture of the Einet."""
# Build the SPN bottom up:
# Build the SPN top down:
# Definition from RAT Paper
# Leaf Region: Create I leaf nodes
# Root Region: Create C sum nodes
@@ -473,7 +473,7 @@ def _build_structure_bottom_up(self):
)

def _build_input_distribution_bottom_up(self) -> AbstractLeaf:
"""Construct the input distribution layer. This constructs a direct leaf and not a FactorizedLeaf since the bottom_up approach does not factorize."""
"""Construct the input distribution layer. This constructs a direct leaf and not a FactorizedLeaf since the bottom-up approach does not factorize."""
# Cardinality is the size of the region in the last partitions
return self.config.leaf_type(
num_features=self.config.num_features,
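The structure identifiers changed from `original`/`bottom_up` to `top-down`/`bottom-up`; a tiny stand-alone illustration of the accepted values (the real check lives in `EinetConfig.assert_valid`):

```python
# Illustrative check only, mirroring the renamed values above.
VALID_STRUCTURES = ("top-down", "bottom-up")


def check_structure(structure: str) -> None:
    assert structure in VALID_STRUCTURES, (
        f"Invalid structure type {structure}. Must be 'top-down' or 'bottom-up'."
    )


check_structure("top-down")    # ok (called "original" before this commit)
check_structure("bottom-up")   # ok (called "bottom_up" before this commit)
# check_structure("original")  # would now raise AssertionError
```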
