From b96323b520f5ddce837a75ccb733452e745d3b21 Mon Sep 17 00:00:00 2001
From: Michael Fuest
Date: Wed, 9 Oct 2024 15:28:45 -0400
Subject: [PATCH] reworked conditioning module

---
 config/model_config.yaml                     |   4 +-
 datasets/timeseries_dataset.py               |   6 +-
 eval/evaluator.py                            |   5 +
 generator/conditioning.py                    | 119 ++++++++++++++++---
 generator/data_generator.py                  |  55 ++++++---
 generator/diffusion_ts/gaussian_diffusion.py |  33 +++--
 generator/gan/acgan.py                       |  92 +++++++++-----
 generator/options.py                         |   2 +
 main.py                                      |  10 +-
 requirements.txt                             |   1 -
 setup.py                                     |   1 -
 11 files changed, 244 insertions(+), 84 deletions(-)

diff --git a/config/model_config.yaml b/config/model_config.yaml
index 1f0cece..2067d2e 100644
--- a/config/model_config.yaml
+++ b/config/model_config.yaml
@@ -5,6 +5,7 @@ noise_dim: 256
 cond_emb_dim: 64
 shuffle: True
 sparse_conditioning_loss_weight: 0.8 # sparse conditioning training sample weight for loss computation [0, 1]
+freeze_cond_after_warmup: False # specify whether to freeze conditioning module parameters after warm-up epochs
 
 conditioning_vars: # for each desired conditioning variable, add the name and number of categories
   month: 12
@@ -70,4 +71,5 @@ acgan:
   n_epochs: 200
   lr_gen: 3e-4
   lr_discr: 1e-4
-  warm_up_epochs: 50
+  warm_up_epochs: 100
+  include_auxiliary_losses: True
diff --git a/datasets/timeseries_dataset.py b/datasets/timeseries_dataset.py
index edd9904..4650c5c 100644
--- a/datasets/timeseries_dataset.py
+++ b/datasets/timeseries_dataset.py
@@ -31,12 +31,14 @@ def __len__(self):
         return len(self.data)
 
     def __getitem__(self, idx):
-        time_series = self.data.iloc[idx][self.time_series_column]
+        sample = self.data.iloc[idx]
+
+        time_series = sample[self.time_series_column]
         time_series = torch.tensor(time_series, dtype=torch.float32)
 
         conditioning_vars_dict = {}
         for var in self.conditioning_vars:
-            value = self.data.iloc[idx][var]
+            value = sample[var]
             conditioning_vars_dict[var] = torch.tensor(value, dtype=torch.long)
 
         return time_series, conditioning_vars_dict
diff --git a/eval/evaluator.py b/eval/evaluator.py
index debadb4..127020d 100644
--- a/eval/evaluator.py
+++ b/eval/evaluator.py
@@ -74,6 +74,11 @@ def evaluate_model(
 
         model = self.get_trained_model(dataset)
 
+        if model.opt.sparse_conditioning_loss_weight != 0.5:
+            data_label += "/rare_upweighted"
+        else:
+            data_label += "/equal_weight"
+
         print("----------------------")
         if user_id is not None:
             print(f"Starting evaluation for user {user_id}")
diff --git a/generator/conditioning.py b/generator/conditioning.py
index d131904..4e458de 100644
--- a/generator/conditioning.py
+++ b/generator/conditioning.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+from torch.utils.tensorboard import SummaryWriter
 
 
 class ConditioningModule(nn.Module):
@@ -21,11 +22,10 @@ def __init__(self, categorical_dims, embedding_dim, device):
             nn.Linear(128, embedding_dim),
         ).to(device)
 
-        # Variables for collecting embeddings and computing Gaussian parameters
-        self.embeddings_list = []
         self.mean_embedding = None
         self.cov_embedding = None
         self.inverse_cov_embedding = None  # For Mahalanobis distance
+        self.n_samples = 0
 
     def forward(self, categorical_vars):
         embeddings = []
@@ -36,28 +36,65 @@ def forward(self, categorical_vars):
         conditioning_vector = self.mlp(conditioning_matrix)
         return conditioning_vector
 
-    def collect_embeddings(self, categorical_vars):
+    def initialize_statistics(self, embeddings):
         """
-        Collect conditional embeddings during the warm-up period.
+        Initialize mean and covariance using the embeddings from the first batch after warm-up.
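+
+        With B embeddings e_1..e_B this is the plain sample estimate
+
+            mean = (1 / B) * sum_b e_b
+            cov  = (1 / (B - 1)) * sum_b (e_b - mean)(e_b - mean)^T + 1e-6 * I,
+
+        where the small diagonal term keeps the covariance invertible for the
+        Mahalanobis distance used below.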
""" - with torch.no_grad(): - embedding = self.forward(categorical_vars) - self.embeddings_list.append(embedding.cpu()) + self.mean_embedding = torch.mean(embeddings, dim=0) + centered_embeddings = embeddings - self.mean_embedding.unsqueeze(0) + cov_matrix = torch.matmul(centered_embeddings.T, centered_embeddings) / ( + embeddings.size(0) - 1 + ) + cov_matrix += ( + torch.eye(cov_matrix.size(0)).to(self.device) * 1e-6 + ) # For numerical stability + self.cov_embedding = cov_matrix + self.inverse_cov_embedding = torch.inverse(self.cov_embedding) + self.n_samples = embeddings.size(0) - def compute_gaussian_parameters(self): + def update_running_statistics(self, embeddings): """ - Compute mean and covariance of the collected embeddings. + Update mean and covariance using an online algorithm. """ - all_embeddings = torch.cat(self.embeddings_list, dim=0) - self.mean_embedding = torch.mean(all_embeddings, dim=0).to(self.device) - # Compute covariance - centered_embeddings = all_embeddings - self.mean_embedding.cpu() - cov_matrix = torch.matmul(centered_embeddings.T, centered_embeddings) / ( - all_embeddings.size(0) - 1 + batch_size = embeddings.size(0) + if self.n_samples == 0: + self.initialize_statistics(embeddings) + return + + new_mean = torch.mean(embeddings, dim=0) + delta = new_mean - self.mean_embedding + total_samples = self.n_samples + batch_size + + # Update mean + self.mean_embedding = ( + self.n_samples * self.mean_embedding + batch_size * new_mean + ) / total_samples + + # Compute batch covariance + centered_embeddings = embeddings - new_mean.unsqueeze(0) + batch_cov = torch.matmul(centered_embeddings.T, centered_embeddings) / ( + batch_size - 1 ) - cov_matrix += torch.eye(cov_matrix.size(0)) * 1e-6 - self.cov_embedding = cov_matrix.to(self.device) - self.inverse_cov_embedding = torch.inverse(self.cov_embedding) + batch_cov += ( + torch.eye(batch_cov.size(0)).to(self.device) * 1e-6 + ) # For numerical stability + + # Update covariance + delta_outer = ( + torch.ger(delta, delta) * self.n_samples * batch_size / (total_samples**2) + ) + self.cov_embedding = ( + self.n_samples * self.cov_embedding + batch_size * batch_cov + delta_outer + ) / total_samples + + self.n_samples = total_samples + + # Update inverse covariance matrix + cov_embedding_reg = ( + self.cov_embedding + + torch.eye(self.cov_embedding.size(0)).to(self.device) * 1e-6 + ) + self.inverse_cov_embedding = torch.inverse(cov_embedding_reg) def compute_mahalanobis_distance(self, embeddings): """ @@ -85,3 +122,49 @@ def is_rare(self, embeddings, threshold=None, percentile=0.8): threshold = torch.quantile(mahalanobis_distance, percentile) rare_mask = mahalanobis_distance > threshold return rare_mask + + def log_embedding_statistics( + self, + epoch, + writer, + previous_mean_embedding, + previous_cov_embedding, + batch_embeddings, + ): + """ + Log embedding statistics to TensorBoard. 
+ """ + # Log current mean norm and covariance norm + + if previous_mean_embedding is not None: + mean_embedding_norm = torch.norm(self.mean_embedding).item() + cov_embedding_norm = torch.norm(self.cov_embedding).item() + else: + mean_embedding_norm = 0 + cov_embedding_norm = 0 + + writer.add_scalar("Embedding/Mean_Norm", mean_embedding_norm, epoch) + writer.add_scalar("Embedding/Covariance_Norm", cov_embedding_norm, epoch) + + # Log changes in mean and covariance norms + if previous_mean_embedding is not None: + mean_diff = torch.norm(self.mean_embedding - previous_mean_embedding).item() + writer.add_scalar("Embedding/Mean_Difference", mean_diff, epoch) + if previous_cov_embedding is not None: + cov_diff = torch.norm(self.cov_embedding - previous_cov_embedding).item() + writer.add_scalar("Embedding/Covariance_Difference", cov_diff, epoch) + + # Compute Mahalanobis distances for logging + sample_embeddings = batch_embeddings + if sample_embeddings.size(0) > 0: + mahalanobis_distances = self.compute_mahalanobis_distance( + sample_embeddings.to(self.device) + ) + writer.add_histogram( + "Embedding/Mahalanobis_Distances", mahalanobis_distances.cpu(), epoch + ) + + # Log rarity threshold + percentile = 0.8 # Same as used in is_rare() + threshold = torch.quantile(mahalanobis_distances, percentile).item() + writer.add_scalar("Embedding/Rarity_Threshold", threshold, epoch) diff --git a/generator/data_generator.py b/generator/data_generator.py index e74fca5..71bbf87 100644 --- a/generator/data_generator.py +++ b/generator/data_generator.py @@ -1,3 +1,5 @@ +from typing import Dict + import pandas as pd import torch @@ -47,18 +49,41 @@ def _initialize_model(self): else: raise ValueError(f"Model {self.model_name} not recognized.") - def fit(self, X): - """ - Train the model on the given dataset. - Args: - X: Input data. Should be a compatible dataset object or pandas DataFrame. - """ - if isinstance(X, pd.DataFrame): - dataset = self._prepare_dataset(X) - else: - dataset = X - self.model.train_model(dataset) +def fit(self, X): + """ + Train the model on the given dataset. + + Args: + X: Input data. Should be a compatible dataset object or pandas DataFrame. + """ + if isinstance(X, pd.DataFrame): + dataset = self._prepare_dataset(X) + else: + dataset = X + + sample_timeseries, sample_cond_vars = dataset[0] + expected_seq_len = self.model.opt.seq_len + assert ( + sample_timeseries.shape[0] == expected_seq_len + ), f"Expected timeseries length {expected_seq_len}, but got {sample_timeseries.shape[0]}" + + if ( + hasattr(self.model_params, "conditioning_vars") + and self.model_params.conditioning_vars + ): + for var in self.model_params.conditioning_vars: + assert ( + var in sample_cond_vars.keys() + ), f"Conditioning variable '{var}' specified in model_params.conditioning_vars not found in dataset" + + expected_input_dim = self.model.opt.input_dim + assert sample_timeseries.shape == ( + expected_seq_len, + expected_input_dim, + ), f"Expected timeseries shape ({expected_seq_len}, {expected_input_dim}), but got {sample_timeseries.shape}" + + self.model.train_model(dataset) def generate(self, conditioning_vars): """ @@ -105,7 +130,9 @@ def load(self, path): self.model.load_state_dict(torch.load(path)) self.model.to(self.device) - def _prepare_dataset(self, df: pd.DataFrame): + def _prepare_dataset( + self, df: pd.DataFrame, timeseries_colname: str, conditioning_vars: Dict = None + ): """ Convert a pandas DataFrame into the required dataset format. 
@@ -120,8 +147,10 @@ def _prepare_dataset(self, df: pd.DataFrame):
         elif isinstance(df, pd.DataFrame):
+            if conditioning_vars is None:
+                conditioning_vars = self.conditioning_vars
             dataset = TimeSeriesDataset(
                 dataframe=df,
-                conditioning_vars=self.conditioning_vars,
-                time_series_column="timeseries",
+                conditioning_vars=conditioning_vars,
+                time_series_column=timeseries_colname,
             )
             return dataset
         else:
diff --git a/generator/diffusion_ts/gaussian_diffusion.py b/generator/diffusion_ts/gaussian_diffusion.py
index f9aedbd..9a01109 100644
--- a/generator/diffusion_ts/gaussian_diffusion.py
+++ b/generator/diffusion_ts/gaussian_diffusion.py
@@ -422,18 +422,24 @@ def train_model(self, train_dataset):
 
                 current_batch_size = time_series_batch.size(0)
 
-                if epoch < self.warm_up_epochs:
-                    self.conditioning_module.collect_embeddings(conditioning_vars_batch)
-                elif epoch == self.warm_up_epochs and i == 0:
-                    self.conditioning_module.compute_gaussian_parameters()
-                else:
-                    with torch.no_grad():
-                        embeddings = self.conditioning_module(conditioning_vars_batch)
-                        rare_mask = (
-                            self.conditioning_module.is_rare(embeddings)
-                            .to(self.device)
-                            .float()
-                        )
+                rare_mask = torch.zeros((current_batch_size,), device=self.device)
+
+                if epoch > self.warm_up_epochs:
+                    batch_embeddings = self.conditioning_module(conditioning_vars_batch)
+                    self.conditioning_module.update_running_statistics(
+                        batch_embeddings.detach()
+                    )
+
+                    if self.opt.freeze_cond_after_warmup:
+                        # If specified, freeze conditioning module training
+                        for param in self.conditioning_module.parameters():
+                            param.requires_grad = False
+
+                    rare_mask = (
+                        self.conditioning_module.is_rare(batch_embeddings)
+                        .to(self.device)
+                        .float()
+                    )
 
                 self.optimizer.zero_grad()
 
@@ -470,8 +471,14 @@ def train_model(self, train_dataset):
                         conditioning_vars=conditioning_vars_non_rare,
                     )
 
+                    N_r = rare_mask.sum().item()
+                    N_nr = (torch.logical_not(rare_mask)).sum().item()
+                    N = current_batch_size
                     _lambda = self.sparse_conditioning_loss_weight
-                    loss = _lambda * loss_rare + (1 - _lambda) * loss_non_rare
+                    loss = (
+                        _lambda * (N_r / N) * loss_rare
+                        + (1 - _lambda) * (N_nr / N) * loss_non_rare
+                    )
 
                     loss = loss / self.opt.gradient_accumulate_every
                     loss.backward()
diff --git a/generator/gan/acgan.py b/generator/gan/acgan.py
index c6d898a..b582672 100644
--- a/generator/gan/acgan.py
+++ b/generator/gan/acgan.py
@@ -13,9 +13,12 @@
 Note: Please ensure compliance with the repository's license and credit the
 original authors when using or distributing this code.
""" +import os + import torch import torch.nn as nn import torch.optim as optim +from torch.utils.tensorboard.writer import SummaryWriter from tqdm import tqdm from datasets.utils import prepare_dataloader @@ -158,7 +161,7 @@ def __init__(self, opt): self.warm_up_epochs = opt.warm_up_epochs self.sparse_conditioning_loss_weight = opt.sparse_conditioning_loss_weight - # self.writer = SummaryWriter() + self.writer = SummaryWriter(log_dir=os.path.join("runs", "acgan")) assert ( self.seq_len % 8 == 0 @@ -196,6 +199,9 @@ def train_model(self, dataset): num_epoch = self.opt.n_epochs train_loader = prepare_dataloader(dataset, batch_size) + previous_mean_embedding = None + previous_embedding_covariance = None + for epoch in range(num_epoch): for batch_index, (time_series_batch, conditioning_vars_batch) in enumerate( tqdm(train_loader, desc=f"Epoch {epoch + 1}") @@ -216,21 +222,26 @@ def train_model(self, dataset): soft_zero, soft_one = 0, 0.95 - rare_mask = torch.ones((current_batch_size,)).to(self.device) + rare_mask = torch.zeros((current_batch_size,)).to(self.device) + + if epoch > self.warm_up_epochs: - if epoch < self.warm_up_epochs: - self.generator.conditioning_module.collect_embeddings( + batch_embeddings = self.generator.conditioning_module( conditioning_vars_batch ) - elif epoch == self.warm_up_epochs and batch_index == 0: - self.generator.conditioning_module.compute_gaussian_parameters() - else: - embeddings = self.generator.conditioning_module( - conditioning_vars_batch + self.generator.conditioning_module.update_running_statistics( + batch_embeddings + ) + + if self.opt.freeze_cond_after_warmup: + for param in self.generator.conditioning_module.parameters(): + param.requires_grad = False # if specified, freeze conditioning module training + + rare_mask = ( + self.generator.conditioning_module.is_rare(batch_embeddings) + .to(self.device) + .float() ) - rare_mask = self.generator.conditioning_module.is_rare( - embeddings - ).to(self.device) # --------------------- # Train Discriminator @@ -246,14 +257,15 @@ def train_model(self, dataset): d_fake_loss = self.adversarial_loss( fake_pred, torch.ones_like(fake_pred) * soft_zero ) - for var_name in self.categorical_dims.keys(): - labels = conditioning_vars_batch[var_name].to(self.device) - d_real_loss += self.auxiliary_loss( - aux_outputs_real[var_name], labels - ) - d_fake_loss += self.auxiliary_loss( - aux_outputs_fake[var_name], labels - ) + if self.opt.include_auxiliary_losses: + for var_name in self.categorical_dims.keys(): + labels = conditioning_vars_batch[var_name].to(self.device) + d_real_loss += self.auxiliary_loss( + aux_outputs_real[var_name], labels + ) + d_fake_loss += self.auxiliary_loss( + aux_outputs_fake[var_name], labels + ) d_loss = 0.5 * (d_real_loss + d_fake_loss) d_loss.backward() @@ -285,10 +297,18 @@ def train_model(self, dataset): * soft_one, ) _lambda = self.sparse_conditioning_loss_weight - g_loss = _lambda * g_loss_rare + (1 - _lambda) * g_loss_non_rare - for var_name in self.categorical_dims.keys(): - labels = gen_categorical_vars[var_name] - g_loss += self.auxiliary_loss(aux_outputs[var_name], labels) + N_r = rare_mask.sum().item() + N_nr = (torch.logical_not(rare_mask)).sum().item() + N = current_batch_size + g_loss = ( + _lambda * (N_r / N) * g_loss_rare + + (1 - _lambda) * (N_nr / N) * g_loss_non_rare + ) + + if self.opt.include_auxiliary_losses: + for var_name in self.categorical_dims.keys(): + labels = gen_categorical_vars[var_name] + g_loss += self.auxiliary_loss(aux_outputs[var_name], labels) 
 
                 g_loss.backward()
                 self.optimizer_G.step()
 
                 # -------------------
                 # TensorBoard Logging
                 # -------------------
+                global_step = epoch * len(train_loader) + batch_index
+                # Log overall losses for both generator and discriminator
 
-                # Log overall losses for both generator and discriminator on the same chart
-                # self.writer.add_scalars('GAN Losses', {'Discriminator': d_loss.item(), 'Generator': g_loss.item()}, epoch * len(train_loader) + i)
+                # self.writer.add_scalars('Losses', {'Discriminator': d_loss.item(), 'Generator': g_loss.item()}, global_step)
+
+            # End of epoch logging
+            if epoch > self.warm_up_epochs:
+                self.generator.conditioning_module.log_embedding_statistics(
+                    epoch,
+                    self.writer,
+                    previous_mean_embedding,
+                    previous_embedding_covariance,
+                    batch_embeddings,
+                )
+                previous_mean_embedding = (
+                    self.generator.conditioning_module.mean_embedding.clone()
+                )
+                previous_embedding_covariance = (
+                    self.generator.conditioning_module.cov_embedding.clone()
+                )
 
-        # self.writer.close()
+        self.writer.close()
 
     def sample_conditioning_vars(self, dataset, batch_size, random=False):
         conditioning_vars = {}
diff --git a/generator/options.py b/generator/options.py
index eb08731..8acdc3b 100644
--- a/generator/options.py
+++ b/generator/options.py
@@ -36,6 +36,7 @@ def __init__(self, model_name: str):
         self.cond_emb_dim = config.cond_emb_dim
         self.shuffle = config.shuffle
         self.sparse_conditioning_loss_weight = config.sparse_conditioning_loss_weight
+        self.freeze_cond_after_warmup = config.freeze_cond_after_warmup
         self.categorical_dims = config.get("conditioning_vars", {})
 
         if model_name == "diffcharge":
@@ -111,3 +112,4 @@ def _load_acgan_params(self, model_params):
         self.lr_gen = model_params.lr_gen
         self.lr_discr = model_params.lr_discr
         self.warm_up_epochs = model_params.warm_up_epochs
+        self.include_auxiliary_losses = model_params.include_auxiliary_losses
diff --git a/main.py b/main.py
index ce80093..f3c838d 100644
--- a/main.py
+++ b/main.py
@@ -1,5 +1,3 @@
-from torch.utils.tensorboard import SummaryWriter
-
 from datasets.pecanstreet import PecanStreetDataManager
 from eval.evaluator import Evaluator
 
@@ -38,11 +36,9 @@ def evaluate_single_dataset_model(
     # evaluator.evaluate_all_users()
     # evaluator.evaluate_all_non_pv_users()
     non_pv_user_evaluator.evaluate_model(
-        None, distinguish_rare=False, data_label="non_pv_users"
-    )
-    pv_user_evaluator.evaluate_model(
-        None, distinguish_rare=False, data_label="pv_users"
+        None, distinguish_rare=True, data_label="non_pv_users"
     )
+    pv_user_evaluator.evaluate_model(None, distinguish_rare=True, data_label="pv_users")
 
 
 def main():
@@ -50,7 +46,7 @@ def main():
     # evaluate_individual_user_models("acgan", include_generation=True)
     # evaluate_individual_user_models("acgan", include_generation=False, normalization_method="date")
     evaluate_single_dataset_model(
-        "acgan",
+        "diffusion_ts",
         geography="california",
         include_generation=False,
         normalization_method="group",
diff --git a/requirements.txt b/requirements.txt
index 9e0c04c..bca5c5b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,7 +10,6 @@ accelerate
 torch
 torchvision
 tensorboard
-tensorboardX
 pyyaml
 pre-commit
 mypy
diff --git a/setup.py b/setup.py
index ba32eb5..9dc3a83 100644
--- a/setup.py
+++ b/setup.py
@@ -24,7 +24,6 @@
         "accelerate>=0.32.1",
         "torchvision>=0.18.1",
         "tensorboard>=2.5.0",
-        "tensorboardX>=2.6.2.2",
         "pyyaml>=6.0.1",
         "pre-commit>=3.5.0",
         "black>=24.4.2",
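
A minimal usage sketch of the reworked conditioning statistics API (variable
names and shapes are illustrative only; ConditioningModule is the class
changed in generator/conditioning.py above):

    import torch
    from generator.conditioning import ConditioningModule

    module = ConditioningModule(
        categorical_dims={"month": 12, "weekday": 7}, embedding_dim=64, device="cpu"
    )
    batch = {
        "month": torch.randint(0, 12, (256,)),
        "weekday": torch.randint(0, 7, (256,)),
    }
    emb = module(batch)                              # embed categorical vars
    module.update_running_statistics(emb.detach())   # online mean/cov update
    rare_mask = module.is_rare(emb)                  # Mahalanobis-based rarity mask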