From 85b81d4e7bb8b10fb78d4c928295d4e526fc3394 Mon Sep 17 00:00:00 2001 From: Michael Fuest Date: Mon, 14 Oct 2024 09:08:59 -0400 Subject: [PATCH] added checkpointing, loading and saving support --- config/model_config.yaml | 12 +- generator/data_generator.py | 39 +- generator/diffcharge/diffusion.py | 468 ++++++++++++++++--- generator/diffusion_ts/gaussian_diffusion.py | 120 +++-- generator/gan/acgan.py | 71 ++- generator/options.py | 1 + requirements.txt | 3 +- 7 files changed, 610 insertions(+), 104 deletions(-) diff --git a/config/model_config.yaml b/config/model_config.yaml index 2067d2e..4399b7f 100644 --- a/config/model_config.yaml +++ b/config/model_config.yaml @@ -1,4 +1,4 @@ -device: 1 # 0, cpu +device: 0 # 0, cpu seq_len: 96 # should not be changed for the current datasets input_dim: 2 # or 1 depending on user, but is dynamically set noise_dim: 256 @@ -6,6 +6,7 @@ cond_emb_dim: 64 shuffle: True sparse_conditioning_loss_weight: 0.8 # sparse conditioning training sample weight for loss computation [0, 1] freeze_cond_after_warmup: False # specify whether to freeze conditioning module parameters after warmup epochs +save_cycle: 200 # specify number of epochs to save model after conditioning_vars: # for each desired conditioning variable, add the name and number of categories month: 12 @@ -39,7 +40,7 @@ diffusion_ts: base_lr: 1e-4 n_layer_enc: 4 n_layer_dec: 5 - d_model: 256 + d_model: 128 sampling_timesteps: null loss_type: l1 #l2 beta_schedule: cosine #linear @@ -52,9 +53,8 @@ diffusion_ts: padding_size: null use_ff: true reg_weight: null - results_folder: ./Checkpoints_syn + results_folder: ./checkpoints gradient_accumulate_every: 2 - save_cycle: 1000 ema_decay: 0.99 ema_update_interval: 10 lr_scheduler_params: @@ -64,7 +64,9 @@ diffusion_ts: threshold: 1.0e-1 threshold_mode: rel verbose: false - warm_up_epochs: 100 + warm_up_epochs: 200 + use_ema_sampling: False + save_cycle: 1000 acgan: batch_size: 32 diff --git a/generator/data_generator.py b/generator/data_generator.py index df2011e..a6541f1 100644 --- a/generator/data_generator.py +++ b/generator/data_generator.py @@ -111,24 +111,45 @@ def sample_conditioning_vars(self, dataset, num_samples, random=False): """ return self.model.sample_conditioning_vars(dataset, num_samples, random) - def save(self, path): + def save(self, path: str): """ - Save the model to a file. + Save the model, optimizer, and EMA model to a checkpoint file. Args: - path (str): The file path to save the model. - """ - torch.save(self.model.state_dict(), path) + path (str): The file path to save the checkpoint to. + """ + if self.model is None: + raise ValueError("Model is not initialized. Cannot save checkpoint.") + + checkpoint = { + "epoch": getattr(self.model, "current_epoch", None), + "model_state_dict": self.model.state_dict(), + "optimizer_state_dict": ( + getattr(self.model, "optimizer", None).state_dict() + if hasattr(self.model, "optimizer") + else None + ), + "ema_state_dict": ( + getattr(self.model, "ema", None).ema_model.state_dict() + if hasattr(self.model, "ema") + else None + ), + } + + torch.save(checkpoint, path) + print(f"Saved checkpoint to {path}") def load(self, path): """ - Load the model from a file. + Load the model, optimizer, and EMA model from a checkpoint file. Args: - path (str): The file path to load the model from. + path (str): The file path to load the checkpoint from. """ - self.model.load_state_dict(torch.load(path)) - self.model.to(self.device) + if self.model is None: + raise ValueError("Model is not initialized. 
Cannot load checkpoint.") + + self.model.load(path) def _prepare_dataset( self, df: pd.DataFrame, timeseries_colname: str, conditioning_vars: Dict = None diff --git a/generator/diffcharge/diffusion.py b/generator/diffcharge/diffusion.py index d5dc058..2609f04 100644 --- a/generator/diffcharge/diffusion.py +++ b/generator/diffcharge/diffusion.py @@ -1,7 +1,28 @@ -import numpy as np -import scipy.signal as sig +""" +This class is adapted/taken from the Diffusion_TS GitHub repository: + +Repository: https://github.com/Y-debug-sys/Diffusion-TS +Author: Xinyu Yuan +License: MIT License + +Modifications: +- Conditioning and sampling logic +- Added further functions and removed unused functionality +- Added conditioning module logic for rare and non-rare samples +- Implemented saving and loading functionality + +Note: Please ensure compliance with the repository's license and credit the original authors when using or distributing this code. +""" + +import copy +import math +import os +from functools import partial + import torch import torch.nn as nn +import torch.optim as optim +from torch.utils.tensorboard.writer import SummaryWriter from tqdm import tqdm from datasets.utils import prepare_dataloader @@ -9,11 +30,84 @@ from generator.diffcharge.network import CNN from generator.diffcharge.network import Attention from generator.diffusion_ts.gaussian_diffusion import cosine_beta_schedule +from generator.diffusion_ts.model_utils import default +from generator.diffusion_ts.model_utils import extract +from generator.diffusion_ts.model_utils import identity + + +def linear_beta_schedule(timesteps, device): + scale = 1000 / timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32).to( + device + ) + + +def cosine_beta_schedule(timesteps, device, s=0.004): + steps = timesteps + 1 + x = torch.linspace(0, timesteps, steps, dtype=torch.float32).to(device) + alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return torch.clip(betas, 0, 0.999).to(device) + + +class EMA: + """ + Exponential Moving Average (EMA) of model parameters. + """ + def __init__(self, model, beta, update_every, device): + """ + Initialize the EMA class. -class DDPM: + Args: + model (nn.Module): The model to apply EMA to. + beta (float): The decay rate for EMA. + update_every (int): Update EMA every 'update_every' steps. + device (torch.device): Device to store the EMA model. + """ + self.model = model + self.ema_model = copy.deepcopy(model).eval().to(device) + self.beta = beta + self.update_every = update_every + self.step = 0 + self.device = device + for param in self.ema_model.parameters(): + param.requires_grad = False + + def update(self): + """ + Update the EMA parameters. + """ + self.step += 1 + if self.step % self.update_every != 0: + return + with torch.no_grad(): + for ema_param, model_param in zip( + self.ema_model.parameters(), self.model.parameters() + ): + ema_param.data.mul_(self.beta).add_( + model_param.data, alpha=1.0 - self.beta + ) + + def forward(self, x): + """ + Forward pass using EMA model. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output from EMA model. 
+ """ + return self.ema_model(x) + + +class DDPM(nn.Module): def __init__(self, opt): - super().__init__() + super(DDPM, self).__init__() self.opt = opt self.device = opt.device @@ -22,7 +116,7 @@ def __init__(self, opt): categorical_dims=opt.categorical_dims, embedding_dim=opt.cond_emb_dim, device=opt.device, - ) + ).to(self.device) # Initialize the epsilon model if opt.network == "attention": @@ -36,11 +130,9 @@ def __init__(self, opt): beta_end = opt.beta_end if schedule == "linear": - self.beta = torch.linspace( - beta_start, beta_end, self.n_steps, device=self.device - ) + self.beta = linear_beta_schedule(self.n_steps, self.device) elif schedule == "cosine": - self.beta = cosine_beta_schedule(self.n_steps).to(self.device) + self.beta = cosine_beta_schedule(self.n_steps, self.device) else: self.beta = ( torch.linspace( @@ -68,20 +160,75 @@ def __init__(self, opt): self.optimizer, milestones=[p1, p2], gamma=0.1 ) + # Initialize EMA + self.ema = EMA( + self.eps_model, + beta=opt.ema_decay, + update_every=opt.ema_update_interval, + device=self.device, + ) + + # Initialize tracking variables + self.current_epoch = 0 + self.writer = SummaryWriter(log_dir=os.path.join("runs", "ddpm")) + def gather(self, const, t): + """ + Gather specific timestep constants. + + Args: + const (torch.Tensor): Constants tensor. + t (torch.Tensor): Timestep tensor. + + Returns: + torch.Tensor: Gathered constants. + """ return const.gather(-1, t).view(-1, 1, 1) def q_xt_x0(self, x0, t): + """ + Compute mean and variance for q(x_t | x_0). + + Args: + x0 (torch.Tensor): Original data. + t (torch.Tensor): Timesteps. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Mean and variance. + """ alpha_bar = self.gather(self.alpha_bar, t) mean = (alpha_bar.sqrt()) * x0 var = 1 - alpha_bar return mean, var def q_sample(self, x0, t, eps): + """ + Sample from q(x_t | x_0). + + Args: + x0 (torch.Tensor): Original data. + t (torch.Tensor): Timesteps. + eps (torch.Tensor): Noise. + + Returns: + torch.Tensor: Sampled x_t. + """ mean, var = self.q_xt_x0(x0, t) return mean + var.sqrt() * eps def p_sample(self, xt, c, t, guidance_scale=1.0): + """ + Sample from p(x_{t-1} | x_t). + + Args: + xt (torch.Tensor): Current data. + c (torch.Tensor): Conditioning variables. + t (torch.Tensor): Current timestep. + guidance_scale (float): Guidance scale for conditional generation. + + Returns: + torch.Tensor: Sampled x_{t-1}. + """ eps_theta_cond = self.eps_model(xt, c, t) eps_theta_uncond = self.eps_model(xt, torch.zeros_like(c), t) eps_theta = eps_theta_uncond + guidance_scale * ( @@ -99,79 +246,282 @@ def p_sample(self, xt, c, t, guidance_scale=1.0): return mean + var.sqrt() * z def cal_loss(self, x0, c, drop_prob=0.15): + """ + Calculate the training loss. + + Args: + x0 (torch.Tensor): Original data. + c (torch.Tensor): Conditioning variables. + drop_prob (float): Probability to drop conditioning for augmentation. + + Returns: + torch.Tensor: Computed loss. 
+ """ batch_size = x0.shape[0] t = torch.randint(0, self.n_steps, (batch_size,), device=self.device) noise = torch.randn_like(x0) xt = self.q_sample(x0, t, eps=noise) - # Randomly drop conditioning + # Randomly drop conditioning for augmentation if torch.rand(1).item() < drop_prob: c = torch.zeros_like(c) eps_theta = self.eps_model(xt, c, t) return self.loss_func(noise, eps_theta) - def train_model(self, dataset): - batch_size = self.opt.batch_size - epoch_loss = [] - train_loader = prepare_dataloader(dataset, batch_size) + def train_model(self, train_dataset): + """ + Train the DDPM model. + + Args: + train_dataset (torch.utils.data.Dataset): Training dataset. + """ + self.train() + self.to(self.device) + + train_loader = prepare_dataloader( + train_dataset, self.opt.batch_size, shuffle=True + ) + + os.makedirs(self.opt.results_folder, exist_ok=True) - for epoch in range(self.opt.n_epochs): + for epoch in tqdm(range(self.opt.n_epochs), desc="Training"): + self.current_epoch = epoch + 1 batch_loss = [] - for i, (time_series_batch, categorical_vars) in enumerate( + for i, (time_series_batch, conditioning_vars_batch) in enumerate( tqdm(train_loader, desc=f"Epoch {epoch + 1}") ): x0 = time_series_batch.to(self.device) - # Get conditioning vector - c = self.conditioning_module(categorical_vars) + c = self.conditioning_module(conditioning_vars_batch).to(self.device) + + # Compute rare_mask after warm-up epochs + if epoch >= self.opt.warm_up_epochs: + with torch.no_grad(): + if self.opt.freeze_cond_after_warmup: + for param in self.conditioning_module.parameters(): + param.requires_grad = ( + False # Freeze conditioning module + ) + + batch_embeddings = self.conditioning_module( + conditioning_vars_batch + ) + self.conditioning_module.update_running_statistics( + batch_embeddings + ) + rare_mask = ( + self.conditioning_module.is_rare(batch_embeddings) + .to(self.device) + .float() + ) + else: + rare_mask = torch.zeros((x0.size(0),), device=self.device) + self.optimizer.zero_grad() - loss = self.cal_loss(x0, c, drop_prob=0.1) + + if epoch < self.opt.warm_up_epochs: + # Standard loss without separating rare and non-rare + loss = self.cal_loss(x0, c, drop_prob=0.1) + else: + # Separate loss for rare and non-rare samples + rare_indices = (rare_mask == 1.0).nonzero(as_tuple=True)[0] + non_rare_indices = (rare_mask == 0.0).nonzero(as_tuple=True)[0] + + loss_rare = torch.tensor(0.0, device=self.device) + loss_non_rare = torch.tensor(0.0, device=self.device) + + if len(rare_indices) > 0: + x0_rare = x0[rare_indices] + c_rare = c[rare_indices] + loss_rare = self.cal_loss(x0_rare, c_rare, drop_prob=0.0) + + if len(non_rare_indices) > 0: + x0_non_rare = x0[non_rare_indices] + c_non_rare = c[non_rare_indices] + loss_non_rare = self.cal_loss( + x0_non_rare, c_non_rare, drop_prob=0.0 + ) + + N_r = rare_mask.sum().item() + N_nr = (rare_mask == 0.0).sum().item() + N = x0.size(0) + _lambda = self.sparse_conditioning_loss_weight + + loss = ( + _lambda * (N_r / N) * loss_rare + + (1 - _lambda) * (N_nr / N) * loss_non_rare + ) + loss.backward() self.optimizer.step() + self.ema.update() + batch_loss.append(loss.item()) - epoch_loss.append(np.mean(batch_loss)) - print(f"epoch={epoch + 1}/{self.opt.n_epochs}, loss={epoch_loss[-1]}") - self.lr_scheduler.step() - def sample(self, n_samples, categorical_vars, smooth=True, guidance_scale=1.0): - c = self.conditioning_module(categorical_vars).to(self.device) - with torch.no_grad(): - self.eps_model.eval() - x = torch.randn(n_samples, self.opt.seq_len, 
self.opt.input_dim).to( - self.device - ) - for j in tqdm( - range(self.n_steps), desc=f"Sampling steps of {self.n_steps}" - ): - t = torch.full( - (n_samples,), - self.n_steps - j - 1, + # Optional: Logging per batch + # self.writer.add_scalar('Loss/train', loss.item(), epoch * len(train_loader) + i) + + epoch_mean_loss = sum(batch_loss) / len(batch_loss) + print(f"Epoch {epoch + 1}/{self.opt.n_epochs}, Loss: {epoch_mean_loss:.4f}") + + # Scheduler step + self.lr_scheduler.step(epoch_mean_loss) + + if (epoch + 1) % self.opt.save_cycle == 0: + checkpoint_path = os.path.join( + self.opt.results_folder, f"ddpm_checkpoint_epoch_{epoch + 1}.pt" + ) + self.save(checkpoint_path, self.current_epoch) + print(f"Saved checkpoint at {checkpoint_path}.") + + print("Training complete") + self.writer.close() + + def save(self, path: str, epoch: int = None): + """ + Save the DDPM model, optimizer, EMA state, and epoch number. + + Args: + path (str): The file path to save the checkpoint to. + epoch (int, optional): The current epoch number. Defaults to None. + """ + checkpoint = { + "epoch": epoch if epoch is not None else self.current_epoch, + "eps_model_state_dict": self.eps_model.state_dict(), + "optimizer_state_dict": self.optimizer.state_dict(), + "ema_state_dict": self.ema.ema_model.state_dict(), + "alpha_bar": self.alpha_bar.cpu(), + "beta": self.beta.cpu(), + } + torch.save(checkpoint, path) + print(f"Saved DDPM checkpoint to {path}") + + def load(self, path: str): + """ + Load the DDPM model, optimizer, EMA state, and epoch number from a checkpoint file. + + Args: + path (str): The file path to load the checkpoint from. + """ + if not os.path.exists(path): + raise FileNotFoundError(f"Checkpoint file not found at {path}") + + checkpoint = torch.load(path, map_location=self.device) + + # Load epsilon model state + if "eps_model_state_dict" in checkpoint: + self.eps_model.load_state_dict(checkpoint["eps_model_state_dict"]) + print("Loaded epsilon model state.") + else: + raise KeyError("Checkpoint does not contain 'eps_model_state_dict'.") + + # Load optimizer state + if "optimizer_state_dict" in checkpoint: + self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + print("Loaded optimizer state.") + else: + print("No optimizer state found in checkpoint.") + + # Load EMA state + if "ema_state_dict" in checkpoint: + self.ema.ema_model.load_state_dict(checkpoint["ema_state_dict"]) + print("Loaded EMA model state.") + else: + print("No EMA state found in checkpoint.") + + # Load alpha_bar and beta if needed + if "alpha_bar" in checkpoint: + self.alpha_bar = checkpoint["alpha_bar"].to(self.device) + print("Loaded alpha_bar.") + if "beta" in checkpoint: + self.beta = checkpoint["beta"].to(self.device) + print("Loaded beta.") + + # Load epoch number + if "epoch" in checkpoint: + self.current_epoch = checkpoint["epoch"] + print(f"Loaded epoch number: {self.current_epoch}") + else: + print("No epoch information found in checkpoint.") + + # Move models to the correct device + self.eps_model.to(self.device) + self.conditioning_module.to(self.device) + self.ema.ema_model.to(self.device) + print(f"DDPM models moved to {self.device}.") + + def sample_conditioning_vars(self, dataset, batch_size, random=False): + """ + Sample conditioning variables from the dataset. + + Args: + dataset (torch.utils.data.Dataset): The dataset to sample from. + batch_size (int): Number of samples to generate. + random (bool): Whether to sample randomly or from the dataset. 
+ + Returns: + Dict[str, torch.Tensor]: Dictionary of conditioning variables. + """ + conditioning_vars = {} + if random: + for var_name, num_categories in self.opt.categorical_dims.items(): + conditioning_vars[var_name] = torch.randint( + 0, + num_categories, + (batch_size,), dtype=torch.long, device=self.device, ) - x = self.p_sample(x, c, t, guidance_scale=guidance_scale) - if smooth: - for i in range(n_samples): - filtered_x = sig.medfilt(x[i].cpu().numpy(), kernel_size=(5, 1)) - x[i] = torch.tensor(filtered_x, dtype=torch.float32).to(self.device) - return x - - def sample_random_conditioning_vars(self, dataset, batch_size): - sampled_rows = dataset.data.sample(n=batch_size).reset_index(drop=True) - - categorical_vars = {} - for var_name in self.categorical_dims.keys(): - categorical_vars[var_name] = torch.tensor( - sampled_rows[var_name].values, device=self.device, dtype=torch.long - ) + else: + sampled_rows = dataset.data.sample(n=batch_size).reset_index(drop=True) + for var_name in self.opt.categorical_dims.keys(): + conditioning_vars[var_name] = torch.tensor( + sampled_rows[var_name].values, dtype=torch.long, device=self.device + ) - return categorical_vars + return conditioning_vars - def generate(self, categorical_vars): - num_samples = categorical_vars[next(iter(categorical_vars))].shape[0] - return self.sample( - n_samples=num_samples, - categorical_vars=categorical_vars, - smooth=True, - guidance_scale=self.opt.guidance_scale, - ) + def generate(self, conditioning_vars: dict, use_ema_sampling: bool = False): + """ + Generate synthetic time series data using the trained model. + + Args: + conditioning_vars (dict): Conditioning variables for generation. + use_ema_sampling (bool, optional): Whether to use EMA model for generation. Defaults to False. + + Returns: + torch.Tensor: Generated synthetic time series data. + """ + num_samples = conditioning_vars[next(iter(conditioning_vars))].shape[0] + shape = (num_samples, self.opt.seq_len, self.opt.input_dim) + + if use_ema_sampling: + print("Generating using EMA model parameters.") + with torch.no_grad(): + samples = self.ema.ema_model.sample(shape, conditioning_vars) + return samples.cpu().numpy() + else: + print("Generating using regular epsilon model parameters.") + return self._generate(shape, conditioning_vars).cpu().numpy() + + @torch.no_grad() + def _generate(self, shape, conditioning_vars): + """ + Internal method to generate samples using the standard sampling procedure. + + Args: + shape (tuple): Shape of the samples to generate. + conditioning_vars (dict): Conditioning variables for generation. + + Returns: + torch.Tensor: Generated samples. 
+ """ + device = self.beta.device + img = torch.randn(shape, device=device) + for t in tqdm( + reversed(range(0, self.n_steps)), + desc="sampling loop time step", + total=self.n_steps, + ): + img = self.p_sample(img, conditioning_vars, t) + return img diff --git a/generator/diffusion_ts/gaussian_diffusion.py b/generator/diffusion_ts/gaussian_diffusion.py index 9a01109..9846013 100644 --- a/generator/diffusion_ts/gaussian_diffusion.py +++ b/generator/diffusion_ts/gaussian_diffusion.py @@ -304,7 +304,11 @@ def sample_conditioning_vars(self, dataset, batch_size, random=False): def generate(self, conditioning_vars): num_samples = len(conditioning_vars[list(conditioning_vars.keys())[0]]) shape = (num_samples, self.seq_len, self.input_dim) - return self._generate(shape, conditioning_vars) + + if self.opt.use_ema_sampling: + return self.ema.ema_model._generate(shape, conditioning_vars) + else: + return self._generate(shape, conditioning_vars) def _generate(self, shape, conditioning_vars): self.eval() @@ -380,6 +384,7 @@ def forward(self, x, conditioning_vars=None, **kwargs): ) def train_model(self, train_dataset): + self.train() self.to(self.device) train_loader = DataLoader( @@ -401,7 +406,7 @@ def train_model(self, train_dataset): beta=self.opt.ema_decay, update_every=self.opt.ema_update_interval, device=self.device, - ).to(self.device) + ) self.scheduler = ReduceLROnPlateau( self.optimizer, **self.opt.lr_scheduler_params ) @@ -409,6 +414,7 @@ def train_model(self, train_dataset): self.conditioning_module.to(self.device) for epoch in tqdm(range(self.opt.n_epochs), desc="Training"): + self.current_epoch = epoch + 1 total_loss = 0.0 for i, (time_series_batch, conditioning_vars_batch) in enumerate( train_loader @@ -423,18 +429,23 @@ def train_model(self, train_dataset): current_batch_size = time_series_batch.size(0) if epoch > self.warm_up_epochs: - batch_embeddings = self.conditioning_module(conditioning_vars_batch) - self.conditioning_module.update_running_statistics(batch_embeddings) + with torch.no_grad(): - if self.opt.freeze_cond_after_warmup: - for param in self.conditioning_module.parameters(): - param.requires_grad = False # if specified, freeze conditioning module training + if self.opt.freeze_cond_after_warmup: + for param in self.conditioning_module.parameters(): + param.requires_grad = False # if specified, freeze conditioning module training - rare_mask = ( - self.conditioning_module.is_rare(batch_embeddings) - .to(self.device) - .float() - ) + batch_embeddings = self.conditioning_module( + conditioning_vars_batch + ) + self.conditioning_module.update_running_statistics( + batch_embeddings + ) + rare_mask = ( + self.conditioning_module.is_rare(batch_embeddings) + .to(self.device) + .float() + ) self.optimizer.zero_grad() @@ -492,27 +503,80 @@ def train_model(self, train_dataset): self.scheduler.step(total_loss) - if (epoch + 1) % self.opt.save_cycle == 0: - checkpoint_path = os.path.join( - self.opt.results_folder, f"checkpoint-{epoch + 1}.pt" - ) - torch.save( - { - "epoch": epoch + 1, - "model_state_dict": self.state_dict(), - "optimizer_state_dict": self.optimizer.state_dict(), - "ema_state_dict": self.ema.state_dict(), - }, - checkpoint_path, - ) - print(f"Saved checkpoint at {checkpoint_path}") + if (epoch + 1) % self.opt.save_cycle == 0: + checkpoint_path = os.path.join( + self.opt.results_folder, f"checkpoint-{epoch + 1}.pt" + ) + self.save(checkpoint_path, self.current_epoch) + print(f"Saved checkpoint at {checkpoint_path}.") print("Training complete") + def load(self, path: 
str): + """ + Load the model, optimizer, and EMA model from a checkpoint file. + + Args: + path (str): The file path to load the checkpoint from. + """ + checkpoint = torch.load(path, map_location=self.device) + + # Load the regular model state + if "model_state_dict" in checkpoint: + self.load_state_dict(checkpoint["model_state_dict"]) + print("Loaded regular model state.") + else: + raise KeyError("Checkpoint does not contain 'model_state_dict'.") + + if "optimizer_state_dict" in checkpoint and hasattr(self, "optimizer"): + self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + print("Loaded optimizer state.") + else: + print( + "No optimizer state found in checkpoint or optimizer not initialized." + ) + + if "ema_state_dict" in checkpoint and hasattr(self, "ema"): + self.ema.ema_model.load_state_dict(checkpoint["ema_state_dict"]) + print("Loaded EMA model state.") + else: + print("No EMA state found in checkpoint or EMA not initialized.") + + if "epoch" in checkpoint: + self.current_epoch = checkpoint["epoch"] + print(f"Loaded epoch number: {self.current_epoch}") + else: + print("No epoch information found in checkpoint.") + + self.to(self.device) + if hasattr(self, "ema") and self.ema.ema_model: + self.ema.ema_model.to(self.device) + print(f"Model and EMA model moved to {self.device}.") + + def save(self, path: str, epoch: int = None): + model_state_dict_cpu = {k: v.cpu() for k, v in self.state_dict().items()} + optimizer_state_dict_cpu = { + k: v.cpu() if isinstance(v, torch.Tensor) else v + for k, v in self.optimizer.state_dict().items() + } + ema_state_dict_cpu = { + k: v.cpu() for k, v in self.ema.ema_model.state_dict().items() + } + torch.save( + { + "epoch": epoch + 1, + "model_state_dict": model_state_dict_cpu, + "optimizer_state_dict": optimizer_state_dict_cpu, + "ema_state_dict": ema_state_dict_cpu, + }, + path, + ) + -class EMA(nn.Module): +class EMA: def __init__(self, model, beta, update_every, device): super(EMA, self).__init__() + self.model = model self.ema_model = copy.deepcopy(model).eval().to(device) self.beta = beta self.update_every = update_every @@ -527,7 +591,7 @@ def update(self): return with torch.no_grad(): for ema_param, model_param in zip( - self.ema_model.parameters(), self.parameters() + self.ema_model.parameters(), self.model.parameters() ): ema_param.data.mul_(self.beta).add_( model_param.data, alpha=1.0 - self.beta diff --git a/generator/gan/acgan.py b/generator/gan/acgan.py index b582672..623dbf2 100644 --- a/generator/gan/acgan.py +++ b/generator/gan/acgan.py @@ -1,5 +1,5 @@ """ -This class is adapted/taken from the synthetic-timeseries-smart-grid GitHub repository: +This class is adapted from the synthetic-timeseries-smart-grid GitHub repository: Repository: https://github.com/vermouth1992/synthetic-time-series-smart-grid Author: Chi Zhang @@ -203,6 +203,7 @@ def train_model(self, dataset): previous_embedding_covariance = None for epoch in range(num_epoch): + self.current_epoch = epoch + 1 for batch_index, (time_series_batch, conditioning_vars_batch) in enumerate( tqdm(train_loader, desc=f"Epoch {epoch + 1}") ): @@ -338,6 +339,13 @@ def train_model(self, dataset): self.generator.conditioning_module.cov_embedding.clone() ) + if (epoch + 1) % self.opt.save_cycle == 0: + checkpoint_path = os.path.join( + self.opt.results_folder, f"acgan_checkpoint_epoch_{epoch + 1}.pt" + ) + self.save(checkpoint_path, self.current_epoch) + print(f"Saved checkpoint at {checkpoint_path}.") + self.writer.close() def sample_conditioning_vars(self, dataset, 
batch_size, random=False): @@ -366,3 +374,64 @@ def generate(self, conditioning_vars): with torch.no_grad(): generated_data = self.generator(noise, conditioning_vars) return generated_data + + def save(self, path: str, epoch: int = None): + """ + Save the generator and discriminator models, optimizers, and epoch number. + + Args: + path (str): The file path to save the checkpoint to. + epoch (int, optional): The current epoch number. Defaults to None. + """ + checkpoint = { + "epoch": epoch if epoch is not None else self.current_epoch, + "generator_state_dict": self.generator.state_dict(), + "discriminator_state_dict": self.discriminator.state_dict(), + "optimizer_G_state_dict": self.optimizer_G.state_dict(), + "optimizer_D_state_dict": self.optimizer_D.state_dict(), + } + torch.save(checkpoint, path) + print(f"Saved ACGAN checkpoint to {path}") + + def load(self, path: str): + """ + Load the generator and discriminator models, optimizers, and epoch number from a checkpoint file. + + Args: + path (str): The file path to load the checkpoint from. + """ + checkpoint = torch.load(path, map_location=self.device) + + if "generator_state_dict" in checkpoint: + self.generator.load_state_dict(checkpoint["generator_state_dict"]) + print("Loaded generator state.") + else: + raise KeyError("Checkpoint does not contain 'generator_state_dict'.") + + if "discriminator_state_dict" in checkpoint: + self.discriminator.load_state_dict(checkpoint["discriminator_state_dict"]) + print("Loaded discriminator state.") + else: + raise KeyError("Checkpoint does not contain 'discriminator_state_dict'.") + + if "optimizer_G_state_dict" in checkpoint: + self.optimizer_G.load_state_dict(checkpoint["optimizer_G_state_dict"]) + print("Loaded generator optimizer state.") + else: + print("No generator optimizer state found in checkpoint.") + + if "optimizer_D_state_dict" in checkpoint: + self.optimizer_D.load_state_dict(checkpoint["optimizer_D_state_dict"]) + print("Loaded discriminator optimizer state.") + else: + print("No discriminator optimizer state found in checkpoint.") + + if "epoch" in checkpoint: + self.current_epoch = checkpoint["epoch"] + print(f"Loaded epoch number: {self.current_epoch}") + else: + print("No epoch information found in checkpoint.") + + self.generator.to(self.device) + self.discriminator.to(self.device) + print(f"ACGAN models moved to {self.device}.") diff --git a/generator/options.py b/generator/options.py index 8acdc3b..3147169 100644 --- a/generator/options.py +++ b/generator/options.py @@ -99,6 +99,7 @@ def _load_diffusion_ts_params(self, model_params): self.ema_update_interval = model_params.ema_update_interval self.lr_scheduler_params = model_params.lr_scheduler_params self.warm_up_epochs = model_params.warm_up_epochs + self.use_ema_sampling = model_params.use_ema_sampling def _load_acgan_params(self, model_params): """ diff --git a/requirements.txt b/requirements.txt index bca5c5b..c7fc169 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,5 +22,4 @@ dtaidistance seaborn einops sentencepiece -hydra -omegaconfa +omegaconf
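
Usage sketch (illustrative, not part of the diff): the checkpoints written by the new save() methods and by the periodic save_cycle saves are plain torch dictionaries, so they can be inspected and restored as below. The file name is a placeholder; the key names are taken from the hunks above, and the restore calls mirror the load() methods this patch adds.

import torch

# Checkpoints produced by this patch (manual save() calls and the periodic
# save_cycle saves into results_folder) are plain torch dicts. The file name
# below is a placeholder; real names follow patterns such as
# "ddpm_checkpoint_epoch_200.pt", "checkpoint-200.pt" or
# "acgan_checkpoint_epoch_200.pt".
ckpt = torch.load("checkpoints/ddpm_checkpoint_epoch_200.pt", map_location="cpu")

print("epoch:", ckpt["epoch"])
print("keys:", list(ckpt.keys()))
# DDPM stores "eps_model_state_dict", "optimizer_state_dict", "ema_state_dict",
# "alpha_bar" and "beta"; Diffusion-TS and ACGAN use their own key names
# ("model_state_dict", "generator_state_dict", ...), as shown in the diff.

# Restoring goes through the model's own load(), which re-attaches the
# optimizer and EMA state and moves everything to the configured device,
# or through the wrapper in generator/data_generator.py, whose load()
# delegates to the underlying model:
#     model.load("checkpoints/ddpm_checkpoint_epoch_200.pt")
#     data_generator.load("checkpoints/ddpm_checkpoint_epoch_200.pt")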