Skip to content

Commit

Permalink
Merge pull request #2138 from cuicheng01/release/2.4
Browse files Browse the repository at this point in the history
update multilabel_dataset.py
  • Loading branch information
cuicheng01 authored Jul 7, 2022
2 parents b676850 + dcd90c5 commit 3a28ee2
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions ppcls/data/dataloader/multilabel_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class MultiLabelDataset(CommonDataset):
def _load_anno(self, label_ratio=False):
assert os.path.exists(self._cls_path)
assert os.path.exists(self._img_root)
self.label_ratio = label_ratio
self.images = []
self.labels = []
with open(self._cls_path) as fd:
Expand All @@ -41,7 +42,7 @@ def _load_anno(self, label_ratio=False):

self.labels.append(labels)
assert os.path.exists(self.images[-1])
if label_ratio:
if self.label_ratio is not False:
return np.array(self.labels).mean(0).astype("float32")

def __getitem__(self, idx):
Expand All @@ -52,7 +53,7 @@ def __getitem__(self, idx):
img = transform(img, self._transform_ops)
img = img.transpose((2, 0, 1))
label = np.array(self.labels[idx]).astype("float32")
if self.label_ratio is not None:
if self.label_ratio is not False:
return (img, np.array([label, self.label_ratio]))
else:
return (img, label)
Expand Down

0 comments on commit 3a28ee2

Please sign in to comment.