# main.py (forked from jpordoy/AMBER)
# Original listing: 57 lines (42 loc), 1.98 KB
import pandas as pd
import numpy as np
from data_loader import DataLoader
from data_formatter import DataFormatter
from event_splitter import EventSplitter
from model_rf import Amber_RF
#from model import Amber
from kfold_cv import KFoldCrossValidation
from evaluator import evaluate_model_performance
from config import config
# End-to-end training/evaluation script: read the training CSV, segment it
# into windows, split/format the windows, cross-validate a model and print
# test-set metrics.

# Input data. The path is relative, so the script must be run from the
# directory that contains the `Data/` folder.
mypath = 'Data/Train.csv'
df = pd.read_csv(mypath)
target_column = 'label'  # Name of the target column

# Step 1: Load data -- segment the DataFrame using the window length and
# stride taken from `config`.
data_loader = DataLoader(
    dataframe=df,
    time_steps=config.N_TIME_STEPS,
    step=config.step,
    target_column=target_column,
)
segments, labels = data_loader.load_data()

# Step 2: Format data -- produces train/test inputs (indexed below by
# 'Feature_1' / 'Feature_2') and the matching label sets.
data_formatter = DataFormatter(config=config)
X_train_reshaped, X_test_reshaped, y_train, y_test = data_formatter.format_data_by_event(segments, labels)

# Convert the test labels to a float32 NumPy array (this is a dtype cast,
# not a reshape; the variable name is kept for compatibility).
y_test_reshaped = np.asarray(y_test, dtype=np.float32)

# Alternative model with concat fusion (requires `from model import Amber`):
#ts_model = Amber(row_hidden=config.row_hidden, col_hidden=config.row_hidden, num_classes=config.N_CLASSES)

# Initialize model with residual fusion layer.
# The script previously instantiated `PANN`, a name that is neither imported
# nor defined in this file, which raised NameError at this line. `Amber_RF`
# is the model class this file actually imports and takes the same keyword
# arguments.
# NOTE(review): col_hidden is fed config.row_hidden, exactly as in the
# original call -- confirm whether a separate column size was intended.
ts_model = Amber_RF(row_hidden=config.row_hidden, col_hidden=config.row_hidden, num_classes=config.N_CLASSES)

# Create an instance of KFoldCrossValidation over the two training inputs.
kfold_cv = KFoldCrossValidation(ts_model, [X_train_reshaped['Feature_1'], X_train_reshaped['Feature_2']], y_train)

# Run the cross-validation.
# NOTE(review): presumably this fits `ts_model` in place, since the same
# object is evaluated below -- verify against kfold_cv.KFoldCrossValidation.
kfold_cv.run()

# Evaluate the model performance on the held-out test inputs.
evaluation_results = evaluate_model_performance(ts_model, [X_test_reshaped['Feature_1'], X_test_reshaped['Feature_2']], y_test_reshaped)

# Access individual metrics from the returned mapping.
print("Accuracy:", evaluation_results["accuracy"])
print("F1 Score:", evaluation_results["f1"])
print("Cohen's Kappa:", evaluation_results["cohen_kappa"])
print("MCC:", evaluation_results["mcc"])
print()
print("Confusion Matrix:\n", evaluation_results["confusion_matrix"])
print()