Skip to content

Commit

Permalink
linting and adding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Fuest committed Sep 12, 2024
1 parent 1b146e4 commit e2bd62b
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 21 deletions.
14 changes: 10 additions & 4 deletions eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@

from eval.discriminative_metric import discriminative_score_metrics
from eval.metrics import (
Context_FID, calculate_mmd, calculate_period_bound_mse, dynamic_time_warping_dist,
plot_range_with_syn_values, plot_syn_with_closest_real_ts, visualization,)
Context_FID,
calculate_mmd,
calculate_period_bound_mse,
dynamic_time_warping_dist,
plot_range_with_syn_values,
plot_syn_with_closest_real_ts,
visualization,
)
from eval.predictive_metric import predictive_score_metrics
from generator.diffcharge.diffusion import DDPM
from generator.diffusion_ts.gaussian_diffusion import Diffusion_TS
from generator.gan.acgan import ACGAN
from generator.llm.llm import HF, GPT
from generator.llm.llm import GPT, HF
from generator.options import Options

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -300,7 +306,7 @@ def get_trained_model_for_user(self, model_name: str, user_dataset: Any) -> Any:
"diffusion_ts": Diffusion_TS,
"mistral": lambda opt: HF("mistralai/Mistral-7B-Instruct-v0.2"),
"llama": lambda opt: HF("meta-llama/Meta-Llama-3.1-8B"),
"gpt": lambda opt: GPT("gpt-4o")
"gpt": lambda opt: GPT("gpt-4o"),
}

if model_name in model_dict:
Expand Down
6 changes: 5 additions & 1 deletion eval/t2vec/t2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
from eval.loss import hierarchical_contrastive_loss
from eval.t2vec.encoder import TSEncoder
from eval.t2vec.utils import (
centerize_vary_length_series, split_with_nan, take_per_row, torch_pad_nan,)
centerize_vary_length_series,
split_with_nan,
take_per_row,
torch_pad_nan,
)


class TS2Vec:
Expand Down
8 changes: 7 additions & 1 deletion generator/diffusion_ts/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@
from torch import nn

from generator.diffusion_ts.model_utils import (
GELU2, AdaLayerNorm, Conv_MLP, LearnablePositionalEncoding, Transpose, series_decomp,)
GELU2,
AdaLayerNorm,
Conv_MLP,
LearnablePositionalEncoding,
Transpose,
series_decomp,
)


class TrendBlock(nn.Module):
Expand Down
8 changes: 3 additions & 5 deletions generator/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ def __init__(self, name=DEFAULT_MODEL, sep=","):

self.tokenizer.add_special_tokens(special_tokens_dict)

self.tokenizer.pad_token = (
self.tokenizer.eos_token
)
self.tokenizer.pad_token = self.tokenizer.eos_token

valid_tokens = [
self.tokenizer.convert_tokens_to_ids(str(digit)) for digit in VALID_NUMBERS
Expand Down Expand Up @@ -117,7 +115,7 @@ def generate_timeseries(
values = response.split(self.sep)
values = [
v.strip() for v in values if v.strip().replace(".", "", 1).isdigit()
]
]
processed_responses.append(self.sep.join(values))

return processed_responses
Expand Down Expand Up @@ -258,4 +256,4 @@ def load_prompt_template(path):
"""Load the prompt template from a JSON file."""
with open(path) as f:
template = json.load(f)
return template
return template
5 changes: 3 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import torch

from dataclasses.openpower import OpenPowerDataset
from dataclasses.pecanstreet import PecanStreetDataset

import torch

from eval.evaluator import Evaluator


Expand Down
20 changes: 12 additions & 8 deletions tests/test_endata.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,31 @@
"""Tests for `endata` package."""

import unittest
import torch
from unittest.mock import MagicMock, patch

import numpy as np
from unittest.mock import patch, MagicMock
import pandas as pd
import torch

from eval.evaluator import Evaluator
from datasets.pecanstreet import PecanStreetDataset
from eval.evaluator import Evaluator
from generator.gan.acgan import ACGAN
from generator.options import Options


class TestGenerator(unittest.TestCase):
"""Test ACGAN Generator."""

def test_generator_output_shape(self):
opt = Options("acgan")
model = ACGAN(opt)
noise = torch.randn(opt.batch_size, opt.noise_dim).to(opt.device) # Batch of 32 samples with noise_dim=128
noise = torch.randn(opt.batch_size, opt.noise_dim).to(
opt.device
) # Batch of 32 samples with noise_dim=128
month_labels = torch.randint(0, 12, (opt.batch_size,)).to(opt.device)
day_labels = torch.randint(0, 7, (opt.batch_size,)).to(opt.device)

generated_data = model.generator(noise, month_labels, day_labels).to(opt.device)
self.assertEqual(generated_data.shape, (opt.batch_size, opt.seq_len, opt.input_dim)) # Check if output shape is as expected

generated_data = model.generator(noise, month_labels, day_labels).to(opt.device)
self.assertEqual(
generated_data.shape, (opt.batch_size, opt.seq_len, opt.input_dim)
) # Check if output shape is as expected

0 comments on commit e2bd62b

Please sign in to comment.