From 35c367b61a1041f90b4cbff34df73d87172c1430 Mon Sep 17 00:00:00 2001 From: Michael Fuest Date: Fri, 4 Oct 2024 05:46:12 -0400 Subject: [PATCH] eval cleanups --- config/model_config.yaml | 4 +- eval/evaluator.py | 80 ++++++++++++++++++++++++++++++++++++++++ main.py | 8 ++-- 3 files changed, 87 insertions(+), 5 deletions(-) diff --git a/config/model_config.yaml b/config/model_config.yaml index 647fa5f..1f0cece 100644 --- a/config/model_config.yaml +++ b/config/model_config.yaml @@ -67,7 +67,7 @@ diffusion_ts: acgan: batch_size: 32 - n_epochs: 10 + n_epochs: 200 lr_gen: 3e-4 lr_discr: 1e-4 - warm_up_epochs: 5 + warm_up_epochs: 50 diff --git a/eval/evaluator.py b/eval/evaluator.py index d4f0cd4..debadb4 100644 --- a/eval/evaluator.py +++ b/eval/evaluator.py @@ -15,6 +15,9 @@ from eval.metrics import calculate_mmd from eval.metrics import calculate_period_bound_mse from eval.metrics import dynamic_time_warping_dist +from eval.metrics import plot_range_with_syn_values +from eval.metrics import plot_syn_with_closest_real_ts +from eval.metrics import visualization from eval.predictive_metric import predictive_score_metrics from generator.diffcharge.diffusion import DDPM from generator.diffusion_ts.gaussian_diffusion import Diffusion_TS @@ -197,6 +200,9 @@ def evaluate_subset( # Compute metrics self.compute_metrics(real_data_array, syn_data_array, real_data_subset, writer) + # Generate plots + self.create_visualizations(real_data_inv, syn_data_inv, dataset, model, writer) + # Close the writer writer.flush() writer.close() @@ -250,6 +256,80 @@ def compute_metrics( writer.add_scalar("Predictive/score", pred_score) self.metrics["predictive"].append(pred_score) + def create_visualizations( + self, + real_data_df: pd.DataFrame, + syn_data_df: pd.DataFrame, + dataset: Any, + model: Any, + writer: SummaryWriter, + num_samples: int = 100, + num_runs: int = 3, + ): + """ + Create various visualizations for the evaluation results. + + Args: + real_data_df (pd.DataFrame): Inverse-transformed real data. + syn_data_df (pd.DataFrame): Inverse-transformed synthetic data. + dataset (Any): The dataset object. + model (Any): The trained model. + writer (SummaryWriter): TensorBoard writer for logging visualizations. + num_samples (int): Number of samples to generate for visualization. + num_runs (int): Number of visualization runs. + """ + for i in range(num_runs): + # Sample a conditioning variable combination from real data + sample_row = real_data_df.sample(n=1).iloc[0] + conditioning_vars_sample = { + var_name: torch.tensor( + [sample_row[var_name]] * num_samples, + dtype=torch.long, + device=device, + ) + for var_name in model.categorical_dims.keys() + } + + generated_samples = model.generate(conditioning_vars_sample).cpu().numpy() + if generated_samples.ndim == 2: + generated_samples = generated_samples.reshape( + generated_samples.shape[0], -1, generated_samples.shape[1] + ) + + generated_samples_df = pd.DataFrame( + { + var_name: [sample_row[var_name]] * num_samples + for var_name in model.categorical_dims.keys() + } + ) + generated_samples_df["timeseries"] = list(generated_samples) + generated_samples_df["dataid"] = sample_row[ + "dataid" + ] # required for inverse transform + generated_samples_df = dataset.inverse_transform(generated_samples_df) + + # Extract month and weekday for plotting + month = sample_row.get("month", None) + weekday = sample_row.get("weekday", None) + + # Visualization 1: Plot range with synthetic values + range_plot = plot_range_with_syn_values( + real_data_df, generated_samples_df, month, weekday + ) + writer.add_figure(f"Visualizations/Range_Plot_{i}", range_plot) + + # Visualization 2: Plot closest real signals with synthetic values + closest_plot = plot_syn_with_closest_real_ts( + real_data_df, generated_samples_df, month, weekday + ) + writer.add_figure(f"Visualizations/Closest_Real_TS_{i}", closest_plot) + + # Visualization 3: KDE plots for real and synthetic data + real_data_array = np.stack(real_data_df["timeseries"]) + syn_data_array = np.stack(syn_data_df["timeseries"]) + kde_plot = visualization(real_data_array, syn_data_array, "kernel") + writer.add_figure(f"Visualizations/KDE", kde_plot) + def get_trained_model(self, dataset: Any) -> Any: """ Get a trained model for the dataset. diff --git a/main.py b/main.py index 620f9b2..ce80093 100644 --- a/main.py +++ b/main.py @@ -38,9 +38,11 @@ def evaluate_single_dataset_model( # evaluator.evaluate_all_users() # evaluator.evaluate_all_non_pv_users() non_pv_user_evaluator.evaluate_model( - None, distinguish_rare=True, data_label="non_pv_users" + None, distinguish_rare=False, data_label="non_pv_users" + ) + pv_user_evaluator.evaluate_model( + None, distinguish_rare=False, data_label="pv_users" ) - pv_user_evaluator.evaluate_model(None, distinguish_rare=True, data_label="pv_users") def main(): @@ -48,7 +50,7 @@ def main(): # evaluate_individual_user_models("acgan", include_generation=True) # evaluate_individual_user_models("acgan", include_generation=False, normalization_method="date") evaluate_single_dataset_model( - "diffusion_ts", + "acgan", geography="california", include_generation=False, normalization_method="group",