diff --git a/ppcls/data/dataloader/multilabel_dataset.py b/ppcls/data/dataloader/multilabel_dataset.py index 25dfc12b57..c67a5ae78f 100644 --- a/ppcls/data/dataloader/multilabel_dataset.py +++ b/ppcls/data/dataloader/multilabel_dataset.py @@ -28,6 +28,7 @@ class MultiLabelDataset(CommonDataset): def _load_anno(self, label_ratio=False): assert os.path.exists(self._cls_path) assert os.path.exists(self._img_root) + self.label_ratio = label_ratio self.images = [] self.labels = [] with open(self._cls_path) as fd: @@ -41,7 +42,7 @@ def _load_anno(self, label_ratio=False): self.labels.append(labels) assert os.path.exists(self.images[-1]) - if label_ratio: + if self.label_ratio is not False: return np.array(self.labels).mean(0).astype("float32") def __getitem__(self, idx): @@ -52,7 +53,7 @@ def __getitem__(self, idx): img = transform(img, self._transform_ops) img = img.transpose((2, 0, 1)) label = np.array(self.labels[idx]).astype("float32") - if self.label_ratio is not None: + if self.label_ratio is not False: return (img, np.array([label, self.label_ratio])) else: return (img, label)