diff --git a/batchflow/sampler.py b/batchflow/sampler.py index cb69843f2..59413d59c 100644 --- a/batchflow/sampler.py +++ b/batchflow/sampler.py @@ -235,8 +235,13 @@ def sample(self, size): up_sample = self.bases[0].sample(size=up_size) low_sample = self.bases[1].sample(size=low_size) - sample_points = np.concatenate([up_sample, low_sample]) - sample_points = sample_points[np.random.permutation(size)] + if len(up_sample) > 0 and len(low_sample) > 0: + sample_points = np.concatenate([up_sample, low_sample]) + sample_points = sample_points[np.random.permutation(size)] + elif len(up_sample) == 0: + sample_points = low_sample + else: + sample_points = up_sample return sample_points