From e2bd62b48f3de159b89d7fad102a05f9f6f3e95b Mon Sep 17 00:00:00 2001
From: Michael Fuest
Date: Thu, 12 Sep 2024 16:21:05 -0400
Subject: [PATCH] linting and adding tests

---
 eval/evaluator.py                     | 14 ++++++++++----
 eval/t2vec/t2vec.py                   |  6 +++++-
 generator/diffusion_ts/transformer.py |  8 +++++++-
 generator/llm/llm.py                  |  8 +++-----
 main.py                               |  5 +++--
 tests/test_endata.py                  | 20 ++++++++++++--------
 6 files changed, 40 insertions(+), 21 deletions(-)

diff --git a/eval/evaluator.py b/eval/evaluator.py
index 729ff50..77fcaae 100644
--- a/eval/evaluator.py
+++ b/eval/evaluator.py
@@ -7,13 +7,19 @@
 
 from eval.discriminative_metric import discriminative_score_metrics
 from eval.metrics import (
-    Context_FID, calculate_mmd, calculate_period_bound_mse, dynamic_time_warping_dist,
-    plot_range_with_syn_values, plot_syn_with_closest_real_ts, visualization,)
+    Context_FID,
+    calculate_mmd,
+    calculate_period_bound_mse,
+    dynamic_time_warping_dist,
+    plot_range_with_syn_values,
+    plot_syn_with_closest_real_ts,
+    visualization,
+)
 from eval.predictive_metric import predictive_score_metrics
 from generator.diffcharge.diffusion import DDPM
 from generator.diffusion_ts.gaussian_diffusion import Diffusion_TS
 from generator.gan.acgan import ACGAN
-from generator.llm.llm import HF, GPT
+from generator.llm.llm import GPT, HF
 from generator.options import Options
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -300,7 +306,7 @@ def get_trained_model_for_user(self, model_name: str, user_dataset: Any) -> Any:
             "diffusion_ts": Diffusion_TS,
             "mistral": lambda opt: HF("mistralai/Mistral-7B-Instruct-v0.2"),
             "llama": lambda opt: HF("meta-llama/Meta-Llama-3.1-8B"),
-            "gpt": lambda opt: GPT("gpt-4o")
+            "gpt": lambda opt: GPT("gpt-4o"),
         }
 
         if model_name in model_dict:
diff --git a/eval/t2vec/t2vec.py b/eval/t2vec/t2vec.py
index cecaab8..e859e67 100644
--- a/eval/t2vec/t2vec.py
+++ b/eval/t2vec/t2vec.py
@@ -19,7 +19,11 @@
 from eval.loss import hierarchical_contrastive_loss
 from eval.t2vec.encoder import TSEncoder
 from eval.t2vec.utils import (
-    centerize_vary_length_series, split_with_nan, take_per_row, torch_pad_nan,)
+    centerize_vary_length_series,
+    split_with_nan,
+    take_per_row,
+    torch_pad_nan,
+)
 
 
 class TS2Vec:
diff --git a/generator/diffusion_ts/transformer.py b/generator/diffusion_ts/transformer.py
index 5c0c127..296dccc 100644
--- a/generator/diffusion_ts/transformer.py
+++ b/generator/diffusion_ts/transformer.py
@@ -20,7 +20,13 @@
 from torch import nn
 
 from generator.diffusion_ts.model_utils import (
-    GELU2, AdaLayerNorm, Conv_MLP, LearnablePositionalEncoding, Transpose, series_decomp,)
+    GELU2,
+    AdaLayerNorm,
+    Conv_MLP,
+    LearnablePositionalEncoding,
+    Transpose,
+    series_decomp,
+)
 
 
 class TrendBlock(nn.Module):
diff --git a/generator/llm/llm.py b/generator/llm/llm.py
index 1540665..22b00ab 100644
--- a/generator/llm/llm.py
+++ b/generator/llm/llm.py
@@ -56,9 +56,7 @@ def __init__(self, name=DEFAULT_MODEL, sep=","):
 
         self.tokenizer.add_special_tokens(special_tokens_dict)
 
-        self.tokenizer.pad_token = (
-            self.tokenizer.eos_token
-        )
+        self.tokenizer.pad_token = self.tokenizer.eos_token
 
         valid_tokens = [
             self.tokenizer.convert_tokens_to_ids(str(digit)) for digit in VALID_NUMBERS
@@ -117,7 +115,7 @@ def generate_timeseries(
             values = response.split(self.sep)
             values = [
                 v.strip() for v in values if v.strip().replace(".", "", 1).isdigit()
-                ]
+            ]
             processed_responses.append(self.sep.join(values))
 
         return processed_responses
@@ -258,4 +256,4 @@ def load_prompt_template(path):
     """Load the prompt template from a JSON file."""
     with open(path) as f:
         template = json.load(f)
-    return template
\ No newline at end of file
+    return template
diff --git a/main.py b/main.py
index c21649e..d6c1764 100644
--- a/main.py
+++ b/main.py
@@ -1,7 +1,8 @@
-import torch
-
 from dataclasses.openpower import OpenPowerDataset
 from dataclasses.pecanstreet import PecanStreetDataset
+
+import torch
+
 from eval.evaluator import Evaluator
 
 
diff --git a/tests/test_endata.py b/tests/test_endata.py
index 1663cbd..3eb749a 100644
--- a/tests/test_endata.py
+++ b/tests/test_endata.py
@@ -4,27 +4,31 @@
 """Tests for `endata` package."""
 
 import unittest
-import torch
+from unittest.mock import MagicMock, patch
+
 import numpy as np
-from unittest.mock import patch, MagicMock
 import pandas as pd
+import torch
 
-from eval.evaluator import Evaluator
 from datasets.pecanstreet import PecanStreetDataset
+from eval.evaluator import Evaluator
 from generator.gan.acgan import ACGAN
 from generator.options import Options
 
 
 class TestGenerator(unittest.TestCase):
     """Test ACGAN Generator."""
-    
+
     def test_generator_output_shape(self):
         opt = Options("acgan")
         model = ACGAN(opt)
-        noise = torch.randn(opt.batch_size, opt.noise_dim).to(opt.device)  # Batch of 32 samples with noise_dim=128
+        noise = torch.randn(opt.batch_size, opt.noise_dim).to(
+            opt.device
+        )  # Batch of 32 samples with noise_dim=128
         month_labels = torch.randint(0, 12, (opt.batch_size,)).to(opt.device)
         day_labels = torch.randint(0, 7, (opt.batch_size,)).to(opt.device)
-
-        generated_data = model.generator(noise, month_labels, day_labels).to(opt.device)
-        self.assertEqual(generated_data.shape, (opt.batch_size, opt.seq_len, opt.input_dim))  # Check if output shape is as expected
+        generated_data = model.generator(noise, month_labels, day_labels).to(opt.device)
+        self.assertEqual(
+            generated_data.shape, (opt.batch_size, opt.seq_len, opt.input_dim)
+        )  # Check if output shape is as expected
 