Commit
reworked conditioning module
Michael Fuest committed Oct 9, 2024
1 parent 35c367b commit b96323b
Showing 11 changed files with 244 additions and 84 deletions.
4 changes: 3 additions & 1 deletion config/model_config.yaml
@@ -5,6 +5,7 @@ noise_dim: 256
cond_emb_dim: 64
shuffle: True
sparse_conditioning_loss_weight: 0.8 # sparse conditioning training sample weight for loss computation [0, 1]
freeze_cond_after_warmup: False # specify whether to freeze conditioning module parameters after warmup epochs

conditioning_vars: # for each desired conditioning variable, add the name and number of categories
month: 12
@@ -70,4 +71,5 @@ acgan:
n_epochs: 200
lr_gen: 3e-4
lr_discr: 1e-4
warm_up_epochs: 50
warm_up_epochs: 100
include_auxiliary_losses: True
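As a quick illustration of the config shape after this change (variable names and category counts below are placeholders, not values from the repository), the conditioning block with the two new options might look like:

```yaml
# Illustrative sketch only; "weekday" and "building_type" are hypothetical variables.
sparse_conditioning_loss_weight: 0.8   # lambda in [0, 1]; 0.5 weights rare and non-rare samples equally
freeze_cond_after_warmup: False        # freeze conditioning module parameters once warm-up ends

conditioning_vars:                     # name: number of categories
  month: 12
  weekday: 7
  building_type: 5
```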
6 changes: 4 additions & 2 deletions datasets/timeseries_dataset.py
@@ -31,12 +31,14 @@ def __len__(self):
return len(self.data)

def __getitem__(self, idx):
time_series = self.data.iloc[idx][self.time_series_column]
sample = self.data.iloc[idx]

time_series = sample[self.time_series_column]
time_series = torch.tensor(time_series, dtype=torch.float32)

conditioning_vars_dict = {}
for var in self.conditioning_vars:
value = self.data.iloc[idx][var]
value = sample[var]
conditioning_vars_dict[var] = torch.tensor(value, dtype=torch.long)

return time_series, conditioning_vars_dict
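A minimal usage sketch for the updated __getitem__ (the column names, shapes, and the assumption that conditioning_vars is an iterable of column names holding integer category codes are illustrative, not taken from the repository):

```python
import numpy as np
import pandas as pd

from datasets.timeseries_dataset import TimeSeriesDataset

# Toy frame: each row stores one (seq_len, input_dim) array plus integer category codes.
df = pd.DataFrame(
    {
        "timeseries": [np.random.rand(96, 1).astype("float32") for _ in range(8)],
        "month": np.random.randint(0, 12, size=8),   # hypothetical conditioning column
        "weekday": np.random.randint(0, 7, size=8),  # hypothetical conditioning column
    }
)

dataset = TimeSeriesDataset(
    dataframe=df,
    conditioning_vars=["month", "weekday"],
    time_series_column="timeseries",
)

time_series, cond_vars = dataset[0]
print(time_series.shape)   # torch.Size([96, 1])
print(cond_vars["month"])  # 0-dim LongTensor holding the category index
```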
5 changes: 5 additions & 0 deletions eval/evaluator.py
@@ -74,6 +74,11 @@ def evaluate_model(

model = self.get_trained_model(dataset)

if model.opt.sparse_conditioning_loss_weight != 0.5:
data_label += "/rare_upweighed"
else:
data_label += "/equal_weight"

print("----------------------")
if user_id is not None:
print(f"Starting evaluation for user {user_id}")
119 changes: 101 additions & 18 deletions generator/conditioning.py
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter


class ConditioningModule(nn.Module):
@@ -21,11 +22,10 @@ def __init__(self, categorical_dims, embedding_dim, device):
nn.Linear(128, embedding_dim),
).to(device)

# Variables for collecting embeddings and computing Gaussian parameters
self.embeddings_list = []
self.mean_embedding = None
self.cov_embedding = None
self.inverse_cov_embedding = None # For Mahalanobis distance
self.n_samples = 0

def forward(self, categorical_vars):
embeddings = []
@@ -36,28 +36,65 @@ def forward(self, categorical_vars):
conditioning_vector = self.mlp(conditioning_matrix)
return conditioning_vector

def collect_embeddings(self, categorical_vars):
def initialize_statistics(self, embeddings):
"""
Collect conditional embeddings during the warm-up period.
Initialize mean and covariance using the embeddings from the first batch after warm-up.
"""
with torch.no_grad():
embedding = self.forward(categorical_vars)
self.embeddings_list.append(embedding.cpu())
self.mean_embedding = torch.mean(embeddings, dim=0)
centered_embeddings = embeddings - self.mean_embedding.unsqueeze(0)
cov_matrix = torch.matmul(centered_embeddings.T, centered_embeddings) / (
embeddings.size(0) - 1
)
cov_matrix += (
torch.eye(cov_matrix.size(0)).to(self.device) * 1e-6
) # For numerical stability
self.cov_embedding = cov_matrix
self.inverse_cov_embedding = torch.inverse(self.cov_embedding)
self.n_samples = embeddings.size(0)

def compute_gaussian_parameters(self):
def update_running_statistics(self, embeddings):
"""
Compute mean and covariance of the collected embeddings.
Update mean and covariance using an online algorithm.
"""
all_embeddings = torch.cat(self.embeddings_list, dim=0)
self.mean_embedding = torch.mean(all_embeddings, dim=0).to(self.device)
# Compute covariance
centered_embeddings = all_embeddings - self.mean_embedding.cpu()
cov_matrix = torch.matmul(centered_embeddings.T, centered_embeddings) / (
all_embeddings.size(0) - 1
batch_size = embeddings.size(0)
if self.n_samples == 0:
self.initialize_statistics(embeddings)
return

new_mean = torch.mean(embeddings, dim=0)
delta = new_mean - self.mean_embedding
total_samples = self.n_samples + batch_size

# Update mean
self.mean_embedding = (
self.n_samples * self.mean_embedding + batch_size * new_mean
) / total_samples

# Compute batch covariance
centered_embeddings = embeddings - new_mean.unsqueeze(0)
batch_cov = torch.matmul(centered_embeddings.T, centered_embeddings) / (
batch_size - 1
)
cov_matrix += torch.eye(cov_matrix.size(0)) * 1e-6
self.cov_embedding = cov_matrix.to(self.device)
self.inverse_cov_embedding = torch.inverse(self.cov_embedding)
batch_cov += (
torch.eye(batch_cov.size(0)).to(self.device) * 1e-6
) # For numerical stability

# Update covariance
delta_outer = (
torch.ger(delta, delta) * self.n_samples * batch_size / (total_samples**2)
)
self.cov_embedding = (
self.n_samples * self.cov_embedding + batch_size * batch_cov + delta_outer
) / total_samples

self.n_samples = total_samples

# Update inverse covariance matrix
cov_embedding_reg = (
self.cov_embedding
+ torch.eye(self.cov_embedding.size(0)).to(self.device) * 1e-6
)
self.inverse_cov_embedding = torch.inverse(cov_embedding_reg)

def compute_mahalanobis_distance(self, embeddings):
"""
@@ -85,3 +122,49 @@ def is_rare(self, embeddings, threshold=None, percentile=0.8):
threshold = torch.quantile(mahalanobis_distance, percentile)
rare_mask = mahalanobis_distance > threshold
return rare_mask

def log_embedding_statistics(
self,
epoch,
writer,
previous_mean_embedding,
previous_cov_embedding,
batch_embeddings,
):
"""
Log embedding statistics to TensorBoard.
"""
# Log current mean norm and covariance norm

if previous_mean_embedding is not None:
mean_embedding_norm = torch.norm(self.mean_embedding).item()
cov_embedding_norm = torch.norm(self.cov_embedding).item()
else:
mean_embedding_norm = 0
cov_embedding_norm = 0

writer.add_scalar("Embedding/Mean_Norm", mean_embedding_norm, epoch)
writer.add_scalar("Embedding/Covariance_Norm", cov_embedding_norm, epoch)

# Log changes in mean and covariance norms
if previous_mean_embedding is not None:
mean_diff = torch.norm(self.mean_embedding - previous_mean_embedding).item()
writer.add_scalar("Embedding/Mean_Difference", mean_diff, epoch)
if previous_cov_embedding is not None:
cov_diff = torch.norm(self.cov_embedding - previous_cov_embedding).item()
writer.add_scalar("Embedding/Covariance_Difference", cov_diff, epoch)

# Compute Mahalanobis distances for logging
sample_embeddings = batch_embeddings
if sample_embeddings.size(0) > 0:
mahalanobis_distances = self.compute_mahalanobis_distance(
sample_embeddings.to(self.device)
)
writer.add_histogram(
"Embedding/Mahalanobis_Distances", mahalanobis_distances.cpu(), epoch
)

# Log rarity threshold
percentile = 0.8 # Same as used in is_rare()
threshold = torch.quantile(mahalanobis_distances, percentile).item()
writer.add_scalar("Embedding/Rarity_Threshold", threshold, epoch)
55 changes: 41 additions & 14 deletions generator/data_generator.py
@@ -1,3 +1,5 @@
from typing import Dict

import pandas as pd
import torch

@@ -47,18 +49,41 @@ def _initialize_model(self):
else:
raise ValueError(f"Model {self.model_name} not recognized.")

def fit(self, X):
"""
Train the model on the given dataset.

Args:
X: Input data. Should be a compatible dataset object or pandas DataFrame.
"""
if isinstance(X, pd.DataFrame):
dataset = self._prepare_dataset(X)
else:
dataset = X
self.model.train_model(dataset)
def fit(self, X):
"""
Train the model on the given dataset.
Args:
X: Input data. Should be a compatible dataset object or pandas DataFrame.
"""
if isinstance(X, pd.DataFrame):
dataset = self._prepare_dataset(X)
else:
dataset = X

sample_timeseries, sample_cond_vars = dataset[0]
expected_seq_len = self.model.opt.seq_len
assert (
sample_timeseries.shape[0] == expected_seq_len
), f"Expected timeseries length {expected_seq_len}, but got {sample_timeseries.shape[0]}"

if (
hasattr(self.model_params, "conditioning_vars")
and self.model_params.conditioning_vars
):
for var in self.model_params.conditioning_vars:
assert (
var in sample_cond_vars.keys()
), f"Conditioning variable '{var}' specified in model_params.conditioning_vars not found in dataset"

expected_input_dim = self.model.opt.input_dim
assert sample_timeseries.shape == (
expected_seq_len,
expected_input_dim,
), f"Expected timeseries shape ({expected_seq_len}, {expected_input_dim}), but got {sample_timeseries.shape}"

self.model.train_model(dataset)

def generate(self, conditioning_vars):
"""
@@ -105,7 +130,9 @@ def load(self, path):
self.model.load_state_dict(torch.load(path))
self.model.to(self.device)

def _prepare_dataset(self, df: pd.DataFrame):
def _prepare_dataset(
self, df: pd.DataFrame, timeseries_colname: str, conditioning_vars: Dict = None
):
"""
Convert a pandas DataFrame into the required dataset format.
@@ -120,8 +147,8 @@ def _prepare_dataset(self, df: pd.DataFrame):
elif isinstance(df, pd.DataFrame):
dataset = TimeSeriesDataset(
dataframe=df,
conditioning_vars=self.conditioning_vars,
time_series_column="timeseries",
conditioning_vars=conditioning_vars,
time_series_column=timeseries_colname,
)
return dataset
else:
33 changes: 20 additions & 13 deletions generator/diffusion_ts/gaussian_diffusion.py
@@ -422,18 +422,19 @@ def train_model(self, train_dataset):

current_batch_size = time_series_batch.size(0)

if epoch < self.warm_up_epochs:
self.conditioning_module.collect_embeddings(conditioning_vars_batch)
elif epoch == self.warm_up_epochs and i == 0:
self.conditioning_module.compute_gaussian_parameters()
else:
with torch.no_grad():
embeddings = self.conditioning_module(conditioning_vars_batch)
rare_mask = (
self.conditioning_module.is_rare(embeddings)
.to(self.device)
.float()
)
if epoch > self.warm_up_epochs:
batch_embeddings = self.conditioning_module(conditioning_vars_batch)
self.conditioning_module.update_running_statistics(batch_embeddings)

if self.opt.freeze_cond_after_warmup:
for param in self.conditioning_module.parameters():
param.requires_grad = False # if specified, freeze conditioning module training

rare_mask = (
self.conditioning_module.is_rare(batch_embeddings)
.to(self.device)
.float()
)

self.optimizer.zero_grad()

@@ -470,8 +471,14 @@ def train_model(self, train_dataset):
conditioning_vars=conditioning_vars_non_rare,
)

N_r = rare_mask.sum().item()
N_nr = (torch.logical_not(rare_mask)).sum().item()
N = current_batch_size
_lambda = self.sparse_conditioning_loss_weight
loss = _lambda * loss_rare + (1 - _lambda) * loss_non_rare
loss = (
_lambda * (N_r / N) * loss_rare
+ (1 - _lambda) * (N_nr / N) * loss_non_rare
)

loss = loss / self.opt.gradient_accumulate_every
loss.backward()
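In summary, the reweighted objective introduced in this hunk is

```latex
\mathcal{L} \;=\; \lambda\,\frac{N_r}{N}\,\mathcal{L}_{\text{rare}}
\;+\; (1-\lambda)\,\frac{N_{nr}}{N}\,\mathcal{L}_{\text{non-rare}}
```

where λ is sparse_conditioning_loss_weight, N the batch size, N_r the number of samples flagged rare by the conditioning module, and N_nr = N − N_r. As an illustrative (not repository-derived) example, with the config default λ = 0.8 and 4 rare samples in a batch of 32, the rare-sample loss is scaled by 0.8 · 4/32 = 0.10 and the non-rare loss by 0.2 · 28/32 = 0.175.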