added checkpointing, loading and saving support
Michael Fuest committed Oct 14, 2024
1 parent 910fef2 commit 85b81d4
Showing 7 changed files with 610 additions and 104 deletions.
config/model_config.yaml (12 changes: 7 additions, 5 deletions)
@@ -1,11 +1,12 @@
-device: 1 # 0, cpu
+device: 0 # 0, cpu
 seq_len: 96 # should not be changed for the current datasets
 input_dim: 2 # or 1 depending on user, but is dynamically set
 noise_dim: 256
 cond_emb_dim: 64
 shuffle: True
 sparse_conditioning_loss_weight: 0.8 # sparse conditioning training sample weight for loss computation [0, 1]
 freeze_cond_after_warmup: False # specify whether to freeze conditioning module parameters after warmup epochs
+save_cycle: 200 # specify number of epochs to save model after
 
 conditioning_vars: # for each desired conditioning variable, add the name and number of categories
   month: 12
@@ -39,7 +40,7 @@ diffusion_ts:
   base_lr: 1e-4
   n_layer_enc: 4
   n_layer_dec: 5
-  d_model: 256
+  d_model: 128
   sampling_timesteps: null
   loss_type: l1 #l2
   beta_schedule: cosine #linear
@@ -52,9 +53,8 @@
   padding_size: null
   use_ff: true
   reg_weight: null
-  results_folder: ./Checkpoints_syn
+  results_folder: ./checkpoints
   gradient_accumulate_every: 2
-  save_cycle: 1000
   ema_decay: 0.99
   ema_update_interval: 10
   lr_scheduler_params:
@@ -64,7 +64,9 @@
     threshold: 1.0e-1
     threshold_mode: rel
     verbose: false
-  warm_up_epochs: 100
+  warm_up_epochs: 200
+  use_ema_sampling: False
+  save_cycle: 1000
 
 acgan:
   batch_size: 32
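The two save_cycle settings above control how often a checkpoint is written during training. Below is a minimal, hypothetical sketch of a training loop driven by this config; the nn.Linear model, bare loop, and checkpoint filename are illustrative stand-ins rather than this repository's trainer, and config/model_config.yaml is assumed to exist on disk.

# Hypothetical sketch: periodic checkpointing driven by the save_cycle and
# results_folder settings above. The toy model and empty epoch body are
# stand-ins, not this repository's actual training code.
import os

import torch
import torch.nn as nn
import yaml

with open("config/model_config.yaml") as f:
    cfg = yaml.safe_load(f)

save_cycle = cfg["save_cycle"]                           # 200 in this config
results_folder = cfg["diffusion_ts"]["results_folder"]   # ./checkpoints
os.makedirs(results_folder, exist_ok=True)

model = nn.Linear(cfg["input_dim"], 1)                   # toy stand-in model
# PyYAML parses 1e-4 (no decimal point) as a string, hence the float() cast.
optimizer = torch.optim.Adam(
    model.parameters(), lr=float(cfg["diffusion_ts"]["base_lr"])
)

for epoch in range(1, 1001):
    # ... one epoch of training would run here ...
    if epoch % save_cycle == 0:
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            os.path.join(results_folder, f"checkpoint_{epoch}.pt"),
        )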
generator/data_generator.py (39 changes: 30 additions, 9 deletions)
@@ -111,24 +111,45 @@ def sample_conditioning_vars(self, dataset, num_samples, random=False):
"""
return self.model.sample_conditioning_vars(dataset, num_samples, random)

def save(self, path):
def save(self, path: str):
"""
Save the model to a file.
Save the model, optimizer, and EMA model to a checkpoint file.
Args:
path (str): The file path to save the model.
"""
torch.save(self.model.state_dict(), path)
path (str): The file path to save the checkpoint to.
"""
if self.model is None:
raise ValueError("Model is not initialized. Cannot save checkpoint.")

checkpoint = {
"epoch": getattr(self.model, "current_epoch", None),
"model_state_dict": self.model.state_dict(),
"optimizer_state_dict": (
getattr(self.model, "optimizer", None).state_dict()
if hasattr(self.model, "optimizer")
else None
),
"ema_state_dict": (
getattr(self.model, "ema", None).ema_model.state_dict()
if hasattr(self.model, "ema")
else None
),
}

torch.save(checkpoint, path)
print(f"Saved checkpoint to {path}")

def load(self, path):
"""
Load the model from a file.
Load the model, optimizer, and EMA model from a checkpoint file.
Args:
path (str): The file path to load the model from.
path (str): The file path to load the checkpoint from.
"""
self.model.load_state_dict(torch.load(path))
self.model.to(self.device)
if self.model is None:
raise ValueError("Model is not initialized. Cannot load checkpoint.")

self.model.load(path)

def _prepare_dataset(
self, df: pd.DataFrame, timeseries_colname: str, conditioning_vars: Dict = None
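The rewritten load above delegates to self.model.load(path) instead of deserializing the state dict itself, and that model-side method is not part of this diff. Below is a hypothetical sketch of what it could look like inside the model class, assuming it mirrors the checkpoint dict written by save and exposes the current_epoch, optimizer, ema, and device attributes that save's getattr calls imply.

    # Hypothetical model-side counterpart to DataGenerator.save(); the actual
    # implementation is not shown in this commit, so signature and attribute
    # names are assumptions inferred from the save() method above.
    def load(self, path: str):
        # map_location lets CPU-only machines read checkpoints saved on GPU,
        # assuming the model tracks its device in self.device.
        checkpoint = torch.load(path, map_location=self.device)

        self.load_state_dict(checkpoint["model_state_dict"])
        self.to(self.device)

        # Restore optimizer and EMA state only if the checkpoint contains
        # them and this model instance supports them.
        if checkpoint.get("optimizer_state_dict") is not None and hasattr(self, "optimizer"):
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if checkpoint.get("ema_state_dict") is not None and hasattr(self, "ema"):
            self.ema.ema_model.load_state_dict(checkpoint["ema_state_dict"])

        if checkpoint.get("epoch") is not None:
            self.current_epoch = checkpoint["epoch"]

With such a method in place, a round trip through the DataGenerator API would be generator.save("./checkpoints/ckpt.pt") on a trained instance followed by generator.load("./checkpoints/ckpt.pt") on a fresh one (the path here is illustrative).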
