Skip to content

Commit

Permalink
TST: add test
Browse files Browse the repository at this point in the history
  • Loading branch information
frankenjoe committed Mar 28, 2023
1 parent c6a49ce commit 4a338e5
Showing 1 changed file with 27 additions and 11 deletions.
38 changes: 27 additions & 11 deletions tests/test_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,29 +937,45 @@ def test_process_signal(


@pytest.mark.parametrize(
'index,expected_features',
'feature, signal, sampling_rate, index, expected',
[
(
audinterface.Feature(
feature_names=('o1', 'o2', 'o3'),
process_func=feature_extractor,
channels=range(NUM_CHANNELS),
),
SIGNAL_2D,
SAMPLING_RATE,
audinterface.utils.signal_index(
[pd.to_timedelta('0s'), pd.to_timedelta('1s')],
[pd.to_timedelta('2s'), pd.to_timedelta('3s')],
),
np.ones((2, NUM_CHANNELS * NUM_FEATURES)),
),
(
audinterface.Feature(
feature_names=('string'),
process_func=lambda x, sr, idx: ['a', 'abc'][idx],
),
SIGNAL_1D,
SAMPLING_RATE,
audinterface.utils.signal_index(
[pd.to_timedelta('0s'), pd.to_timedelta('1s')],
[pd.to_timedelta('2s'), pd.to_timedelta('3s')],
),
np.array([['a'], ['abc']]),
),
],
)
def test_process_signal_from_index(index, expected_features):
extractor = audinterface.Feature(
feature_names=('o1', 'o2', 'o3'),
process_func=feature_extractor,
channels=range(NUM_CHANNELS),
)
features = extractor.process_signal_from_index(
SIGNAL_2D,
SAMPLING_RATE,
def test_process_signal_from_index(feature, signal, sampling_rate, index,
expected):
df = feature.process_signal_from_index(
signal,
sampling_rate,
index,
)
np.testing.assert_array_equal(features.values, expected_features)
np.testing.assert_array_equal(df.values, expected)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 4a338e5

Please sign in to comment.