Skip to content

Commit

Permalink
Added unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
LucaCerina committed Jun 23, 2022
1 parent 5a578ba commit 9a51539
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 24 deletions.
2 changes: 1 addition & 1 deletion ampdlib/ampdlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def ampd_fast(sig_input:np.ndarray, window_length:int, hop_length:int=None, lsm_

# Define iterations
if window_length < sig_input.shape[0]:
iterations = (sig_input.shape[0] - window_length) / hop_length + 1
iterations = int((sig_input.shape[0] - window_length) // hop_length) + 1
if hop_length is None:
hop_length = window_length
else:
Expand Down
80 changes: 57 additions & 23 deletions ampdlib/tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,64 @@
__copyright__ = "Copyright 2022, Luca Cerina"
__email__ = "[email protected]"

import sys
from time import perf_counter
# import sys
# from time import perf_counter
# import numpy as np
# from scipy.io import loadmat
# sys.path.append('../')
# import ampdLib

# if __name__ == "__main__":
# print("Loading data")
# real_data = loadmat('data.mat')

# print("Find peaks")
# N = 30000
# input_data = real_data['ecg_signal'][0:N,0].flatten()
# tStart = perf_counter()
# ampd_peaks = ampdLib.ampd_fast(input_data)
# tEnd = perf_counter()
# print("Samples: {}".format(N))
# print("Peaks found in: {} seconds".format(tEnd-tStart))

# print("Test results")
# # Find misplaced peaks
# error_peaks = np.sum((ampd_peaks - real_data['real_peaks'][0][0:ampd_peaks.shape[0]]) != 0)
# if(error_peaks == 0):
# print("Test passed")
# else:
# print("Total error {}".format(error_peaks))

import unittest

import ampdlib
import numpy as np
from scipy.io import loadmat
sys.path.append('../')
import ampdLib


class TestLibrary(unittest.TestCase):
def setUp(self) -> None:
print("Loading data")
self.test_data = loadmat('data.mat')
N = 30000
self.input_data = self.test_data['ecg_signal'][0:N,0].flatten()
self.real_peaks = self.test_data['real_peaks'][0]
return super().setUp()

def test_detection(self):
ampd_peaks = ampdlib.ampd_fast(self.input_data, window_length=2000, hop_length=1000)
error_peaks = np.sum((ampd_peaks - self.real_peaks[0:ampd_peaks.shape[0]]) != 0)
self.assertEqual(error_peaks, 0)

def test_assertions(self):
with self.assertRaises(AssertionError):
ampdlib.ampd(self.input_data, lsm_limit=-1)
with self.assertRaises(AssertionError):
ampdlib.ampd(self.input_data, lsm_limit=2)
with self.assertRaises(AssertionError):
ampdlib.ampd_fast(self.input_data, 2000, -1)
with self.assertRaises(AssertionError):
ampdlib.ampd_fast(self.input_data, 2000, 2100)

if __name__ == "__main__":
print("Loading data")
real_data = loadmat('data.mat')

print("Find peaks")
N = 30000
input_data = real_data['ecg_signal'][0:N,0].flatten()
tStart = perf_counter()
ampd_peaks = ampdLib.ampd_fast(input_data)
tEnd = perf_counter()
print("Samples: {}".format(N))
print("Peaks found in: {} seconds".format(tEnd-tStart))

print("Test results")
# Find misplaced peaks
error_peaks = np.sum((ampd_peaks - real_data['real_peaks'][0][0:ampd_peaks.shape[0]]) != 0)
if(error_peaks == 0):
print("Test passed")
else:
print("Total error {}".format(error_peaks))
unittest.main()

0 comments on commit 9a51539

Please sign in to comment.