Skip to content

Commit

Permalink
pre-commits
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Fuest committed Sep 18, 2024
1 parent 31a6431 commit f8671d1
Show file tree
Hide file tree
Showing 19 changed files with 56 additions and 63 deletions.
19 changes: 2 additions & 17 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,19 +1,11 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.6.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace

- repo: https://github.com/pycqa/flake8
rev: 7.1.1
hooks:
- id: flake8
additional_dependencies:
- flake8-isort
args: [--max-line-length=88]

- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
Expand All @@ -25,12 +17,5 @@ repos:
rev: 24.8.0
hooks:
- id: black
args: [--line-length=88]
language_version: python3.8

- repo: https://github.com/pycqa/autoflake
rev: v1.6.0
hooks:
- id: autoflake
name: autoflake (remove unused imports and variables)
args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variables']
types: [python]
6 changes: 4 additions & 2 deletions datasets/openpower.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os
import warnings
from typing import Dict, List, Tuple
from typing import Dict
from typing import List
from typing import Tuple

import numpy as np
import pandas as pd
import torch
import yaml
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import Dataset

warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning)
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
Expand Down
5 changes: 3 additions & 2 deletions eval/discriminative_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
Output: discriminative score (np.abs(classification accuracy - 0.5))
"""

from typing import List, Tuple
from typing import List
from typing import Tuple

import numpy as np
import torch
Expand Down Expand Up @@ -109,7 +110,7 @@ def discriminative_score_metrics(
# Extract time information
ori_time, ori_max_seq_len = extract_time(ori_data)
generated_time, generated_max_seq_len = extract_time(generated_data)
max_seq_len = max(ori_max_seq_len, generated_max_seq_len)
max(ori_max_seq_len, generated_max_seq_len)

# Network parameters
hidden_dim = int(dim * 2)
Expand Down
29 changes: 21 additions & 8 deletions eval/evaluator.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
from typing import Any, Dict, List, Tuple
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple

import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter

from eval.discriminative_metric import discriminative_score_metrics
from eval.metrics import (
Context_FID, calculate_mmd, calculate_period_bound_mse, dynamic_time_warping_dist,
plot_range_with_syn_values, plot_syn_with_closest_real_ts, visualization,)
from eval.metrics import Context_FID
from eval.metrics import calculate_mmd
from eval.metrics import calculate_period_bound_mse
from eval.metrics import dynamic_time_warping_dist
from eval.metrics import plot_range_with_syn_values
from eval.metrics import plot_syn_with_closest_real_ts
from eval.metrics import visualization
from eval.predictive_metric import predictive_score_metrics
from generator.diffcharge.diffusion import DDPM
from generator.diffusion_ts.gaussian_diffusion import Diffusion_TS
from generator.gan.acgan import ACGAN
from generator.llm.llm import GPT, HF
from generator.llm.llm import GPT
from generator.llm.llm import HF
from generator.options import Options

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -458,9 +466,14 @@ def _log_final_results(self):
"""
Log the final evaluation results.
"""
dtw_mean, mmd_mean, mse_mean, fid_mean, discr_mean, pred_mean = (
self.get_summary_metrics()
)
(
dtw_mean,
mmd_mean,
mse_mean,
fid_mean,
discr_mean,
pred_mean,
) = self.get_summary_metrics()

self.writer.add_scalar("Final_Results/DTW", dtw_mean)
self.writer.add_scalar("Final_Results/MMD", mmd_mean)
Expand Down
1 change: 0 additions & 1 deletion eval/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


def compute_pairwise_distances(x, y):
Expand Down
5 changes: 3 additions & 2 deletions eval/metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Callable, List, Tuple
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -10,7 +10,8 @@
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from eval.loss import gaussian_kernel_matrix, maximum_mean_discrepancy
from eval.loss import gaussian_kernel_matrix
from eval.loss import maximum_mean_discrepancy
from eval.t2vec.t2vec import TS2Vec


Expand Down
2 changes: 1 addition & 1 deletion eval/predictive_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def predictive_score_metrics(ori_data, generated_data):

ori_time, ori_max_seq_len = extract_time(ori_data)
generated_time, generated_max_seq_len = extract_time(generated_data)
max_seq_len = max([ori_max_seq_len, generated_max_seq_len])
max([ori_max_seq_len, generated_max_seq_len])

hidden_dim = max(int(dim / 2), 1)
iterations = 5000
Expand Down
2 changes: 0 additions & 2 deletions eval/t2vec/dilated_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
Note: Please ensure compliance with the repository's license and credit the original authors when using or distributing this code.
"""

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

Expand Down
1 change: 0 additions & 1 deletion eval/t2vec/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from eval.t2vec.dilated_conv import DilatedConvEncoder
Expand Down
9 changes: 6 additions & 3 deletions eval/t2vec/t2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

from eval.loss import hierarchical_contrastive_loss
from eval.t2vec.encoder import TSEncoder
from eval.t2vec.utils import (
centerize_vary_length_series, split_with_nan, take_per_row, torch_pad_nan,)
from eval.t2vec.utils import centerize_vary_length_series
from eval.t2vec.utils import split_with_nan
from eval.t2vec.utils import take_per_row
from eval.t2vec.utils import torch_pad_nan


class TS2Vec:
Expand Down
1 change: 0 additions & 1 deletion eval/t2vec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
Note: Please ensure compliance with the repository's license and credit the original authors when using or distributing this code.
"""

import os
import pickle
import random
from datetime import datetime
Expand Down
1 change: 0 additions & 1 deletion generator/conditioning.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConditioningModule(nn.Module):
Expand Down
6 changes: 1 addition & 5 deletions generator/diffusion_ts/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,10 @@
"""

import math
from functools import partial

import scipy
import torch
import torch.nn.functional as F
from einops import rearrange, reduce
from scipy.fftpack import next_fast_len
from torch import einsum, nn
from torch import nn


def exists(x):
Expand Down
12 changes: 9 additions & 3 deletions generator/diffusion_ts/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,17 @@
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from einops import rearrange
from einops import reduce
from einops import repeat
from torch import nn

from generator.diffusion_ts.model_utils import (
GELU2, AdaLayerNorm, Conv_MLP, LearnablePositionalEncoding, Transpose, series_decomp,)
from generator.diffusion_ts.model_utils import GELU2
from generator.diffusion_ts.model_utils import AdaLayerNorm
from generator.diffusion_ts.model_utils import Conv_MLP
from generator.diffusion_ts.model_utils import LearnablePositionalEncoding
from generator.diffusion_ts.model_utils import Transpose
from generator.diffusion_ts.model_utils import series_decomp


class TrendBlock(nn.Module):
Expand Down
1 change: 0 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from datasets.openpower import OpenPowerDataset
from datasets.pecanstreet import PecanStreetDataManager
from eval.evaluator import Evaluator

Expand Down
9 changes: 4 additions & 5 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,10 @@ exclude = docs, .tox, .git, __pycache__, .ipynb_checkpoints
ignore = W503

[isort]
include_trailing_comma = True
line_length=99
lines_between_types = 0
multi_line_output = 4
use_parentheses = True
profile = black
line_length = 88
combine_as_imports = true
force_single_line = true

[aliases]
test = pytest
6 changes: 0 additions & 6 deletions tests/test_endata.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,11 @@
import os
import unittest
from typing import List
from unittest.mock import MagicMock, patch

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import mean_squared_error

from datasets.pecanstreet import PecanStreetDataManager
from eval.evaluator import Evaluator
from generator.gan.acgan import ACGAN
from generator.options import Options

TEST_CONFIG_PATH = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "test_data_config.yaml"
Expand Down
2 changes: 1 addition & 1 deletion tests/testdata/15minute_data_austin.csv
Original file line number Diff line number Diff line change
Expand Up @@ -19986,4 +19986,4 @@ dataid,local_15min,air1,air2,air3,airwindowunit1,aquarium1,bathroom1,bathroom2,b
1642,2018-04-14 23:15:00-05,0.002,,,,,0.001,,,,,,,,-0.001,,,0.313,,,,0.000,,-0.001,0.000,,0.091,,,,1.333,,,,,,,,,0.000,0.001,,,,,,,0.492,,0.003,,,,0.008,,,,,,,,0.170,,,,,-0.009,,,,-0.007,0.000,,,,,123.188,123.062
1642,2018-04-14 23:30:00-05,0.002,,,,,0.002,,,,,,,,0.000,,,0.293,,,,0.000,,-0.002,0.000,,0.185,,,,1.071,,,,,,,,,0.000,0.001,,,,,,,0.219,,0.004,,,,0.008,,,,,,,,0.168,,,,,-0.011,,,,-0.018,0.010,,,,,123.381,123.072
1642,2018-04-14 23:45:00-05,0.001,,,,,0.002,,,,,,,,0.001,,,0.294,,,,0.000,,-0.002,0.000,,0.101,,,,0.787,,,,,,,,,0.000,0.001,,,,,,,0.030,,0.004,,,,0.008,,,,,,,,0.167,,,,,-0.011,,,,0.000,0.000,,,,,123.616,123.212
1642,2018-04-15 00:00:00-05,0.002,,,,,0.002,,,,,,,,0.001,,,0.296,,,,0.000,,-0.003,0.000,,0.231,,,,0.893,,,,,,,,,0.000,0.002,,,,,,,0.030,,0.004,,,,0.008,,,,,,,,0.166,,,,,-0.011,,,,0.000,0.000,,,,,123.829,123.346
1642,2018-04-15 00:00:00-05,0.002,,,,,0.002,,,,,,,,0.001,,,0.296,,,,0.000,,-0.003,0.000,,0.231,,,,0.893,,,,,,,,,0.000,0.002,,,,,,,0.030,,0.004,,,,0.008,,,,,,,,0.166,,,,,-0.011,,,,0.000,0.000,,,,,123.829,123.346
2 changes: 1 addition & 1 deletion tests/testdata/15minute_data_california.csv
Original file line number Diff line number Diff line change
Expand Up @@ -6714,4 +6714,4 @@ dataid,local_15min,air1,air2,air3,airwindowunit1,aquarium1,bathroom1,bathroom2,b
3687,2014-03-11 23:00:00-05,0.000,,,,,,,,,,,,,,,,,0.004,,,0.000,0.000,,,,0.004,,,,1.049,,,,,,,,,0.005,,,,,,,,0.423,,0.006,,,,0.001,,,,,,,,0.373,,,,,,,,,,,,,,,120.271,121.162
3687,2014-03-11 23:15:00-05,0.000,,,,,,,,,,,,,,,,,0.004,,,0.000,0.000,,,,0.004,,,,0.893,,,,,,,,,0.004,,,,,,,,0.435,,0.006,,,,0.000,,,,,,,,0.241,,,,,,,,,,,,,,,120.520,121.225
3687,2014-03-11 23:30:00-05,0.000,,,,,,,,,,,,,,,,,0.004,,,0.000,0.000,,,,0.004,,,,1.145,,,,,,,,,0.307,,,,,,,,0.465,,0.006,,,,0.002,,,,,,,,0.127,,,,,,,,,,,,,,,120.504,121.450
3687,2014-03-11 23:45:00-05,0.000,,,,,,,,,,,,,,,,,0.004,,,0.000,0.000,,,,0.004,,,,0.883,,,,,,,,,0.009,,,,,,,,0.422,,0.006,,,,0.000,,,,,,,,0.181,,,,,,,,,,,,,,,120.497,121.445
3687,2014-03-11 23:45:00-05,0.000,,,,,,,,,,,,,,,,,0.004,,,0.000,0.000,,,,0.004,,,,0.883,,,,,,,,,0.009,,,,,,,,0.422,,0.006,,,,0.000,,,,,,,,0.181,,,,,,,,,,,,,,,120.497,121.445

0 comments on commit f8671d1

Please sign in to comment.