forked from chen0040/keras-anomaly-detection
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlstm_autoencoder.py
37 lines (28 loc) · 1.27 KB
/
lstm_autoencoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os

import pandas as pd
from sklearn.preprocessing import MinMaxScaler

from keras_anomaly_detection.library.plot_utils import visualize_reconstruction_error
from keras_anomaly_detection.library.recurrent import LstmAutoEncoder
# Toggle for main(): when True, a fresh LstmAutoEncoder is fitted and saved
# before the model is loaded back; when False, main() only loads a model that
# was previously saved to its model directory.
DO_TRAINING: bool = False
def main(data_dir_path='./data', model_dir_path='./models', do_training=None):
    """Run LSTM-autoencoder anomaly detection on the ECG discord test set.

    Steps:
    - Read ``ecg_discord_test.csv`` (no header row) from *data_dir_path*.
    - Rescale the values with ``MinMaxScaler``.
    - Optionally fit an ``LstmAutoEncoder`` on the first 23 rows and save it
      to *model_dir_path*.
    - Load the saved model back and score those same rows.
    - Print a normal/abnormal verdict and the distance for each row.
    - Pass all distances and the model threshold to
      ``visualize_reconstruction_error``.

    Args:
        data_dir_path: Directory containing ``ecg_discord_test.csv``.
        model_dir_path: Directory the model is saved to and loaded from.
        do_training: Whether to fit and save a new model before loading it.
            ``None`` (the default) falls back to the module-level
            ``DO_TRAINING`` flag.
    """
    csv_path = os.path.join(data_dir_path, 'ecg_discord_test.csv')
    ecg_data = pd.read_csv(csv_path, header=None)
    print(ecg_data.head())

    # DataFrame.as_matrix() was removed in pandas 1.0; to_numpy() replaces it.
    ecg_np_data = ecg_data.to_numpy()
    scaler = MinMaxScaler()
    ecg_np_data = scaler.fit_transform(ecg_np_data)
    print(ecg_np_data.shape)

    # Training and scoring must use the same subset: the first 23 rows.
    # NOTE(review): 23 is hard-coded; confirm it is the intended subset size.
    sample = ecg_np_data[:23, :]

    if do_training is None:
        do_training = DO_TRAINING

    ae = LstmAutoEncoder()
    if do_training:
        # Fit the data and save the model into model_dir_path.
        ae.fit(sample, model_dir_path=model_dir_path, estimated_negative_sample_ratio=0.9)

    # Load back the model saved in model_dir_path and use it to detect anomalies.
    ae.load_model(model_dir_path)
    anomaly_information = ae.anomaly(sample)

    reconstruction_error = []
    for idx, (is_anomaly, dist) in enumerate(anomaly_information):
        status = 'abnormal' if is_anomaly else 'normal'
        print('# {} is {} (dist: {})'.format(idx, status, dist))
        reconstruction_error.append(dist)

    visualize_reconstruction_error(reconstruction_error, ae.threshold)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()