diff --git a/partitura/performance.py b/partitura/performance.py index a04e726b..95d91ed4 100644 --- a/partitura/performance.py +++ b/partitura/performance.py @@ -362,7 +362,7 @@ def __str__(self): return f"PerformedNote: {self['id']}" def __eq__(self, other): - if not isinstance(PerformedNote): + if not isinstance(other, PerformedNote): return False if not self.keys() == other.keys(): return False diff --git a/tests/test_performance.py b/tests/test_performance.py index 2c864348..ab40ae28 100644 --- a/tests/test_performance.py +++ b/tests/test_performance.py @@ -6,7 +6,7 @@ import unittest import numpy as np -from partitura.performance import PerformedPart, Performance +from partitura.performance import PerformedPart, Performance, PerformedNote RNG = np.random.RandomState(1984) @@ -138,3 +138,195 @@ def generate_random_note_array(n_notes=100, beat_period=0.5, n_tracks=3): idx = track_idxs[i * track_length :] note_array["track"][idx] = i return note_array + + +class TestPerformedNote(unittest.TestCase): + def test_initialization(self): + """ + Test that the notes are initialized correctly + """ + pnote = PerformedNote( + { + "pitch": 60, + "note_on": 10, + "note_off": 20, + } + ) + self.assertEqual(pnote["pitch"], 60) + self.assertEqual(pnote["note_on"], 10) + self.assertEqual(pnote["note_off"], 20) + + # Test initialization to default values + self.assertEqual(pnote["track"], 0) + self.assertEqual(pnote["channel"], 1) + self.assertEqual(pnote["velocity"], 60) + self.assertEqual(pnote["sound_off"], 20) + + def test_validate_values(self): + """ + Test for _validate_values + """ + + # Invalid pitch + with self.assertRaises(ValueError): + PerformedNote( + { + "pitch": 128, + "note_on": 10, + "note_off": 20, + } + ) + + # invalid note on + with self.assertRaises(ValueError): + PerformedNote( + { + "pitch": 60, + "note_on": -5, + "note_off": 20, + } + ) + + # invalid note off (note off < note on) + with self.assertRaises(ValueError): + PerformedNote( + { + "pitch": 60, + "note_on": 10, + "note_off": 5, + } + ) + + # invalid velocity (>127) + with self.assertRaises(ValueError): + PerformedNote( + { + "pitch": 60, + "note_on": 10, + "note_off": 20, + "velocity": 150, + } + ) + + # validate sound off ( pnote1) + self.assertTrue(pnote2 >= pnote1) + + def test_get_and_set_item(self): + pnote = PerformedNote( + { + "pitch": 60, + "note_on": 10, + "note_off": 20, + } + ) + pnote["velocity"] = 80 + self.assertEqual(pnote["velocity"], 80) + + with self.assertRaises(KeyError): + pnote["invalid_key"] = 100 + + def test_copy(self): + pnote = PerformedNote( + { + "pitch": 60, + "note_on": 10, + "note_off": 20, + } + ) + pnote_copy = pnote.copy() + self.assertEqual(pnote, pnote_copy) + self.assertIsNot(pnote, pnote_copy) + + def test_cannot_delete_items(self): + pnote = PerformedNote( + { + "pitch": 60, + "note_on": 10, + "note_off": 20, + } + ) + with self.assertRaises(KeyError): + del pnote["pitch"] + +