From 3591786b6507173323b7b657a98b9bdffabbee46 Mon Sep 17 00:00:00 2001 From: ycy Date: Mon, 27 May 2019 15:44:09 +0800 Subject: [PATCH] fix assertion 't >= 0 && t < n_classes' error in a very naive way. Need to find out the actual cause --- data_loader/data_loaders.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/data_loader/data_loaders.py b/data_loader/data_loaders.py index 61c614b..83c61ba 100644 --- a/data_loader/data_loaders.py +++ b/data_loader/data_loaders.py @@ -65,6 +65,10 @@ def __getitem__(self, index): x = np.clip(x, -1, 1) x = self.f2c(self.mulaw(x)) + # fix possible outlier + x = np.clip(x, 0, self.q_channels - 1) + t = np.clip(t, 0, self.q_channels - 1) + f = interp1d(self.hop_idx[:condition.shape[1]], condition, copy=False, axis=1) h = f(np.arange(pos + 1, pos + 1 + self.segment)) @@ -77,3 +81,13 @@ def __init__(self, steps, data_dir, batch_size, num_workers, **kwargs): self.data_dir = data_dir self.dataset = _WAVDataset(data_dir, batch_size * steps, **kwargs) super().__init__(self.dataset, batch_size, num_workers=num_workers) + + +if __name__ == '__main__': + mulaw = np_mulaw(256) + f2c = float2class(256) + + import numpy as np + + x = np.random.randn(10000) + print(f2c(mulaw(x)).max(), f2c(mulaw(x)).min())