Skip to content

Commit

Permalink
Fix: do not truncate string features (#110)
Browse files Browse the repository at this point in the history
* fix feature extraction when return value is a string

* fix flake

* TST: add test
  • Loading branch information
frankenjoe authored Mar 29, 2023
1 parent a242a55 commit 755b4d6
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 20 deletions.
11 changes: 2 additions & 9 deletions audinterface/core/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,8 +767,6 @@ def _series_to_frame(
dtype=object,
)

num = len(y)

if (
self.win_dur is not None and
self.process_func_applies_sliding_window
Expand Down Expand Up @@ -825,17 +823,12 @@ def _series_to_frame(

index = utils.signal_index(starts, ends)

data = np.concatenate(data)

else:

index = y.index
dtype = self._values_to_frame(y[0]).dtype
shape = (num, len(self.column_names))
data = np.empty(shape, dtype)
data = [self._values_to_frame(values) for values in y]

for idx, values in enumerate(y):
data[idx, :] = self._values_to_frame(values)
data = np.concatenate(data)

df = pd.DataFrame(
data,
Expand Down
38 changes: 27 additions & 11 deletions tests/test_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,29 +937,45 @@ def test_process_signal(


@pytest.mark.parametrize(
'index,expected_features',
'feature, signal, sampling_rate, index, expected',
[
(
audinterface.Feature(
feature_names=('o1', 'o2', 'o3'),
process_func=feature_extractor,
channels=range(NUM_CHANNELS),
),
SIGNAL_2D,
SAMPLING_RATE,
audinterface.utils.signal_index(
[pd.to_timedelta('0s'), pd.to_timedelta('1s')],
[pd.to_timedelta('2s'), pd.to_timedelta('3s')],
),
np.ones((2, NUM_CHANNELS * NUM_FEATURES)),
),
(
audinterface.Feature(
feature_names=('string'),
process_func=lambda x, sr, idx: ['a', 'abc'][idx],
),
SIGNAL_1D,
SAMPLING_RATE,
audinterface.utils.signal_index(
[pd.to_timedelta('0s'), pd.to_timedelta('1s')],
[pd.to_timedelta('2s'), pd.to_timedelta('3s')],
),
np.array([['a'], ['abc']]),
),
],
)
def test_process_signal_from_index(index, expected_features):
extractor = audinterface.Feature(
feature_names=('o1', 'o2', 'o3'),
process_func=feature_extractor,
channels=range(NUM_CHANNELS),
)
features = extractor.process_signal_from_index(
SIGNAL_2D,
SAMPLING_RATE,
def test_process_signal_from_index(feature, signal, sampling_rate, index,
expected):
df = feature.process_signal_from_index(
signal,
sampling_rate,
index,
)
np.testing.assert_array_equal(features.values, expected_features)
np.testing.assert_array_equal(df.values, expected)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 755b4d6

Please sign in to comment.