Skip to content

Commit

Permalink
FIX raise ImportError when data is not available (#27)
Browse files Browse the repository at this point in the history
Co-authored-by: Thomas Moreau <[email protected]>
Jad-yehya and tomMoral authored Nov 26, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 9322a74 commit 6727b4c
Showing 4 changed files with 70 additions and 46 deletions.
50 changes: 50 additions & 0 deletions benchmark_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
# the usual import syntax

from benchopt import safe_import_context
from pathlib import Path

with safe_import_context() as import_ctx:
import numpy as np
@@ -43,3 +44,52 @@ def mean_overlaping_pred(predictions, stride):
averaged_predictions = accumulated / counts

return averaged_predictions


def check_data(data_path, dataset, data_type):
"""
Checks if the data is present in the specified path.
Args:
data_path: str
The path to the data directory.
dataset: str
The name of the dataset, either 'WADI' or 'SWaT'.
data_type: str
The type of data, either 'train' or 'test'.
Raises:
ImportError: If the required data files are not found.
"""
if dataset == "WADI":
if data_type == "train":
required_files = ["WADI_14days_new.csv"]
elif data_type == "test":
required_files = ["WADI_attackdataLABLE.csv"]
else:
raise ValueError("data_type must be either 'train' or 'test'")
elif dataset == "SWaT":
if data_type == "train":
required_files = ["swat_train2.csv"]
elif data_type == "test":
required_files = ["swat2.csv"]
else:
raise ValueError("data_type must be either 'train' or 'test'")
else:
raise ValueError("dataset must be either 'WADI' or 'SWaT'")

for file in required_files:
if not Path(data_path, file).exists():
official_repo = {
"WADI": "https://itrust.sutd.edu.sg/itrust-labs_datasets/\
dataset_info/",
"SWaT": "https://drive.google.com/drive/folders/\
1xhcYqh6okRs98QJomFWBKNLw4d1T4Q0w"
}
raise ImportError(
f"{data_type.capitalize()} data not found for {dataset}. "
"Please download the data "
"from the official repository "
f"{official_repo[dataset]}"
f"and place it in {data_path}"
)
6 changes: 3 additions & 3 deletions datasets/msl.py
Original file line number Diff line number Diff line change
@@ -48,9 +48,9 @@ def get_data(self):
with open(path / "MSL_test_label.npy", "wb") as f:
f.write(response.content)

X_train = np.load(path / "MSL_train.npy")
X_test = np.load(path / "MSL_test.npy")
y_test = np.load(path / "MSL_test_label.npy")
X_train = np.load(path / "MSL_train.npy", allow_pickle=True)
X_test = np.load(path / "MSL_test.npy", allow_pickle=True)
y_test = np.load(path / "MSL_test_label.npy", allow_pickle=True)

# Limiting the size of the dataset for testing purposes
if self.debug:
30 changes: 8 additions & 22 deletions datasets/swat.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from benchopt import BaseDataset, safe_import_context
from benchopt.config import get_data_path
from benchmark_utils import check_data

with safe_import_context() as import_ctx:
import pandas as pd

# Checking if the data is available
PATH = get_data_path(key="SWaT")
check_data(PATH, "SWaT", "train")
check_data(PATH, "SWaT", "test")


class Dataset(BaseDataset):
name = "SWaT"
@@ -21,29 +27,9 @@ def get_data(self):
# at the following link:
# https://drive.google.com/drive/folders/1xhcYqh6okRs98QJomFWBKNLw4d1T4Q0w

path = get_data_path(key="SWaT")

if not (path / "swat_train2.csv").exists():
raise FileNotFoundError(
"Train data not found. Please download the data "
"from the Google Drive "
"https://drive.google.com/drive/folders/"
"1xhcYqh6okRs98QJomFWBKNLw4d1T4Q0w"
f" and place it in {path}"
)

if not (path / "swat2.csv").exists():
raise FileNotFoundError(
"Test data not found. Please download the data "
"from the Google Drive "
"https://drive.google.com/drive/folders/"
"1xhcYqh6okRs98QJomFWBKNLw4d1T4Q0w"
f" and place it in {path}"
)

# Load the data
X_train = pd.read_csv(path / "swat_train2.csv")
X_test = pd.read_csv(path / "swat2.csv")
X_train = pd.read_csv(PATH / "swat_train2.csv")
X_test = pd.read_csv(PATH / "swat2.csv")

# Extract the target
y_test = X_test["Normal/Attack"].values
30 changes: 9 additions & 21 deletions datasets/wadi.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from benchopt import BaseDataset, safe_import_context
from benchopt.config import get_data_path
from benchmark_utils import check_data

with safe_import_context() as import_ctx:
import pandas as pd

# Checking if the data is available
PATH = get_data_path(key="WADI")
check_data(PATH, "WADI", "train")
check_data(PATH, "WADI", "test")


class Dataset(BaseDataset):
name = "WADI"
@@ -21,27 +27,9 @@ def get_data(self):
# at the following link:
# https://itrust.sutd.edu.sg/itrust-labs_datasets/dataset_info/

path = get_data_path(key="WADI")

if not (path / "WADI_14days_new.csv").exists():
raise FileNotFoundError(
"Train data not found. Please download the data "
"from the official repository"
"https://itrust.sutd.edu.sg/itrust-labs_datasets/dataset_info/"
f"and place it in {path}"
)

if not (path / "WADI_attackdataLABLE.csv").exists():
raise FileNotFoundError(
"Test data not found. Please download the data "
"from the official repository"
"https://itrust.sutd.edu.sg/itrust-labs_datasets/dataset_info/"
f"and place it in {path}"
)

# Load the data
X_train = pd.read_csv(path / "WADI_14days_new.csv")
X_test = pd.read_csv(path / "WADI_attackdataLABLE.csv", header=1)
X_train = pd.read_csv(PATH / "WADI_14days_new.csv")
X_test = pd.read_csv(PATH / "WADI_attackdataLABLE.csv", header=1)

# Data processing
# Dropping the following colummns because more than 50% of the values
@@ -64,7 +52,7 @@ def get_data(self):
y_test = X_test["Attack LABLE (1:No Attack, -1:Attack)"].values
X_test.drop(
columns=todrop + [
"Attack LABLE (1:No Attack, -1:Attack)"],
"Attack LABLE (1:No Attack, -1:Attack)"],
inplace=True
)
# Using ffill to fill the missing values because

0 comments on commit 6727b4c

Please sign in to comment.