From fe459f4c2bdde76d1b8d05843616496e2a0a830e Mon Sep 17 00:00:00 2001 From: BreezeWhite Date: Mon, 10 May 2021 10:46:42 +0800 Subject: [PATCH] Fix passing wrong type of parameters --- omnizart/utils.py | 5 +++-- tests/test_utils.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/omnizart/utils.py b/omnizart/utils.py index f12d8b1..c8c5a8d 100644 --- a/omnizart/utils.py +++ b/omnizart/utils.py @@ -11,6 +11,7 @@ import jsonschema import pretty_midi +import numpy as np import scipy.io.wavfile as wave from omnizart.constants.midi import SOUNDFONT_PATH @@ -363,7 +364,7 @@ def aggregate_f0_info(pred, t_unit): start_idx = 0 last_hz = pred[0] eps = 1e-6 - pred.append(0) # Append an additional zero to the end temporarily. + pred = np.append(pred, 0) # Append an additional zero to the end temporarily. while cur_idx < len(pred): cur_hz = pred[cur_idx] if abs(cur_hz - last_hz) < eps: @@ -390,5 +391,5 @@ def aggregate_f0_info(pred, t_unit): cur_idx += 1 last_hz = cur_hz - del pred[-1] # Remove the additional ending zero. + pred = pred[:-1] # Remove the additional ending zero. return results diff --git a/tests/test_utils.py b/tests/test_utils.py index b0dd0b0..8d1fe84 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,6 +2,7 @@ import shutil import pytest +import numpy as np from jsonschema import ValidationError from omnizart import utils @@ -187,7 +188,7 @@ def test_serializable_recursive_value_path(): def test_aggregate_f0_info(): t_unit = 0.01 - data = [0, 0, 0, 440, 440, 440, 440, 0, 0, 0, 220, 220] + data = np.array([0, 0, 0, 440, 440, 440, 440, 0, 0, 0, 220, 220]) expected = [ {"start_time": 0.03, "end_time": 0.07, "frequency": 440, "pitch": 69}, {"start_time": 0.1, "end_time": 0.12, "frequency": 220, "pitch": 57}