diff --git a/DatasetLoader.py b/DatasetLoader.py
index 9ec2672..47e7c11 100644
--- a/DatasetLoader.py
+++ b/DatasetLoader.py
@@ -229,16 +229,33 @@ def __iter__(self):
             flattened_list.append([data[i] for i in indices])
 
         ## Mix data in random order
-        mixid = torch.randperm(len(flattened_label), generator=g).tolist()
+        mixid = torch.randperm(len(flattened_label), generator=g).tolist() # numpy.arange(len(flattened_label)).tolist()
         mixlabel = []
         mixmap = []
-
-        ## Prevent two pairs of the same speaker in the same batch
-        for ii in mixid:
-            startbatch = round_down(len(mixlabel), self.batch_size)
-            if flattened_label[ii] not in mixlabel[startbatch:]:
-                mixlabel.append(flattened_label[ii])
-                mixmap.append(ii)
+        resmixid = []
+        mixlabel_ins = 1 # start at 1 so the while loop below runs at least once
+
+        # ## Prevent two pairs of the same speaker in the same batch
+        # for ii in mixid:
+        #     startbatch = round_down(len(mixlabel), self.batch_size)
+        #     if flattened_label[ii] not in mixlabel[startbatch:]:
+        #         mixlabel.append(flattened_label[ii])
+        #         mixmap.append(ii)
+        #         mixlabel_ins += 1
+
+        ## Prevent two pairs of the same speaker in the same batch (reduce data waste with "resmixid")
+        while len(mixid)>0 and mixlabel_ins>0:
+            mixlabel_ins = 0
+            for ii in mixid:
+                startbatch = round_down(len(mixlabel), self.batch_size)
+                if flattened_label[ii] not in mixlabel[startbatch:]:
+                    mixlabel.append(flattened_label[ii])
+                    mixmap.append(ii)
+                    mixlabel_ins += 1
+                else:
+                    resmixid.append(ii)
+            mixid = resmixid
+            resmixid = []
 
         mixed_list = [flattened_list[i] for i in mixmap]
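
For reviewers who want to exercise the new retry logic outside the sampler, below is a minimal, self-contained sketch of the while-loop added in this patch. It is an illustration only: the `round_down` helper is reproduced here, and the `mix_without_repeats` wrapper, the toy label list, and the `batch_size` value are stand-ins introduced for this example, not part of DatasetLoader.py. It shows how indices deferred into `resmixid` are retried on later passes instead of being discarded after a single scan.

import random

def round_down(num, divisor):
    # Same helper as in DatasetLoader.py: largest multiple of divisor <= num.
    return num - (num % divisor)

def mix_without_repeats(flattened_label, batch_size, seed=0):
    """Return indices ordered so no label repeats within a batch.

    Deferred indices are kept in `resmixid` and retried on later passes,
    so fewer samples are discarded than with a single pass.
    """
    rng = random.Random(seed)
    mixid = list(range(len(flattened_label)))
    rng.shuffle(mixid)

    mixlabel, mixmap, resmixid = [], [], []
    mixlabel_ins = 1  # non-zero so the loop runs at least once

    while len(mixid) > 0 and mixlabel_ins > 0:
        mixlabel_ins = 0
        for ii in mixid:
            startbatch = round_down(len(mixlabel), batch_size)
            if flattened_label[ii] not in mixlabel[startbatch:]:
                # Label not yet in the current (partial) batch: accept it.
                mixlabel.append(flattened_label[ii])
                mixmap.append(ii)
                mixlabel_ins += 1
            else:
                # Duplicate within this batch: defer and retry on the next pass.
                resmixid.append(ii)
        mixid, resmixid = resmixid, []

    return mixmap

if __name__ == "__main__":
    labels = ["spk0", "spk0", "spk1", "spk1", "spk2", "spk2", "spk3", "spk3"]
    order = mix_without_repeats(labels, batch_size=2)
    for start in range(0, len(order), 2):
        batch = order[start:start + 2]
        print(batch, [labels[i] for i in batch])

With these example labels, every batch of two holds two distinct speakers, and the loop terminates once a full pass inserts nothing (mixlabel_ins stays 0) or no deferred indices remain.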