-
Notifications
You must be signed in to change notification settings - Fork 657
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #724 from ir2718/dataset
[`feat`] Addition of popular image retrieval benchmark datasets
- Loading branch information
Showing
14 changed files
with
784 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
# Datasets | ||
|
||
Datasets classes give you a way to automatically download a dataset and transform it into a PyTorch dataset. | ||
|
||
All implemented datasets have disjoint train-test splits, ideal for benchmarking on image retrieval and one-shot/few-shot classification tasks. | ||
|
||
## BaseDataset | ||
|
||
All dataset classes extend this class and therefore inherit its ```__init__``` parameters. | ||
|
||
```python | ||
datasets.base_dataset.BaseDataset( | ||
root, | ||
split="train+test", | ||
transform=None, | ||
target_transform=None, | ||
download=False | ||
) | ||
``` | ||
|
||
**Parameters**: | ||
|
||
* **root**: The path where the dataset files are saved. | ||
* **split**: A string that determines which split of the dataset is loaded. | ||
* **transform**: A `torchvision.transforms` object which will be used on the input images. | ||
* **target_transform**: A `torchvision.transforms` object which will be used on the labels. | ||
* **download**: Whether to download the dataset or not. Setting this as False, but not having the dataset on the disk will raise a ValueError. | ||
|
||
**Required Implementations**: | ||
```python | ||
@abstractmethod | ||
def download_and_remove(self): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def generate_split(self): | ||
raise NotImplementedError | ||
``` | ||
|
||
## CUB-200-2011 | ||
|
||
```python | ||
datasets.cub.CUB(*args, **kwargs) | ||
``` | ||
|
||
**Defined splits**: | ||
|
||
- `train` - Consists of 5864 examples, taken from classes 1 to 100. | ||
- `test` - Consists of 5924 examples, taken from classes 101 to 200. | ||
- `train+test` - Consists of 11788 examples, taken from all classes. | ||
|
||
**Loading different dataset splits** | ||
```python | ||
train_dataset = CUB(root="data", | ||
split="train", | ||
transform=None, | ||
target_transform=None, | ||
download=True | ||
) | ||
# No need to download the dataset after it is already downloaded | ||
test_dataset = CUB(root="data", | ||
split="test", | ||
transform=None, | ||
target_transform=None, | ||
download=False | ||
) | ||
train_and_test_dataset = CUB(root="data", | ||
split="train+test", | ||
transform=None, | ||
target_transform=None, | ||
download=False | ||
) | ||
``` | ||
|
||
## Cars196 | ||
|
||
```python | ||
datasets.cars196.Cars196(*args, **kwargs) | ||
``` | ||
|
||
**Defined splits**: | ||
|
||
- `train` - Consists of 8054 examples, taken from classes 1 to 98. | ||
- `test` - Consists of 8131 examples, taken from classes 99 to 196. | ||
- `train+test` - Consists of 16185 examples, taken from all classes. | ||
|
||
**Loading different dataset splits** | ||
```python | ||
train_dataset = Cars196(root="data", | ||
split="train", | ||
transform=None, | ||
target_transform=None, | ||
download=True | ||
) | ||
# No need to download the dataset after it is already downloaded | ||
test_dataset = Cars196(root="data", | ||
split="test", | ||
transform=None, | ||
target_transform=None, | ||
download=False | ||
) | ||
train_and_test_dataset = Cars196(root="data", | ||
split="train+test", | ||
transform=None, | ||
target_transform=None, | ||
download=False | ||
) | ||
``` | ||
|
||
## INaturalist2018 | ||
|
||
```python | ||
datasets.inaturalist2018.INaturalist2018(*args, **kwargs) | ||
``` | ||
|
||
**Defined splits**: | ||
|
||
- `train` - Consists of 325 846 examples. | ||
- `test` - Consists of 136 093 examples. | ||
- `train+test` - Consists of 461 939 examples. | ||
|
||
**Loading different dataset splits** | ||
```python | ||
# The download takes a while - the dataset is very large | ||
train_dataset = INaturalist2018(root="data", | ||
split="train", | ||
transform=None, | ||
target_transform=None, | ||
download=True | ||
) | ||
# No need to download the dataset after it is already downloaded | ||
test_dataset = INaturalist2018(root="data", | ||
split="test", | ||
transform=None, | ||
target_transform=None, | ||
download=False | ||
) | ||
train_and_test_dataset = INaturalist2018(root="data", | ||
split="train+test", | ||
transform=None, | ||
target_transform=None, | ||
download=False | ||
) | ||
``` | ||
|
||
## StanfordOnlineProducts | ||
|
||
```python | ||
datasets.sop.StanfordOnlineProducts(*args, **kwargs) | ||
``` | ||
|
||
**Defined splits**: | ||
|
||
- `train` - Consists of 59551 examples. | ||
- `test` - Consists of 60502 examples. | ||
- `train+test` - Consists of 120 053 examples. | ||
|
||
**Loading different dataset splits** | ||
```python | ||
# The download takes a while - the dataset is very large | ||
train_dataset = StanfordOnlineProducts(root="data", | ||
split="train", | ||
transform=None, | ||
target_transform=None, | ||
download=True | ||
) | ||
# No need to download the dataset after it is already downloaded | ||
test_dataset = StanfordOnlineProducts(root="data", | ||
split="test", | ||
transform=None, | ||
target_transform=None, | ||
download=False | ||
) | ||
train_and_test_dataset = StanfordOnlineProducts(root="data", | ||
split="train+test", | ||
transform=None, | ||
target_transform=None, | ||
download=False | ||
) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# How to write custom datasets | ||
|
||
1. Subclass the ```datasets.base_dataset.BaseDataset``` class | ||
2. Add implementations for abstract methods from the base class: | ||
- ```download_and_remove()``` | ||
- ```generate_split()``` | ||
|
||
|
||
```python | ||
from pytorch_metric_learning.datasets.base_dataset import BaseDataset | ||
|
||
class MyDataset(BaseDataset): | ||
|
||
def __init__(self, my_parameter, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.my_parameter = my_parameter | ||
|
||
def download_and_remove(self): | ||
# Downloads the dataset files needed | ||
# | ||
# If you're using a dataset that you've already downloaded elsewhere, | ||
# just use an empty implementation | ||
pass | ||
|
||
def generate_split(self): | ||
# Creates a list of image paths, and saves them into self.paths | ||
# Creates a list of labels for the images, and saves them into self.labels | ||
# | ||
# The default training splits that need to be covered are `train`, `test`, and `train+test` | ||
# If you need a different split setup, override `get_available_splits(self)` to return | ||
# the split names you want | ||
pass | ||
|
||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import os | ||
from abc import ABC, abstractmethod | ||
|
||
from PIL import Image | ||
from torch.utils.data import Dataset | ||
|
||
|
||
class BaseDataset(ABC, Dataset):
    """Abstract base class for benchmark image-retrieval datasets.

    Subclasses must implement ``generate_split`` (populate ``self.paths``
    and ``self.labels`` for the chosen ``self.split``) and
    ``download_and_remove`` (fetch the dataset archives into ``self.root``
    and delete any temporary files).
    """

    def __init__(
        self,
        root,
        split="train+test",
        transform=None,
        target_transform=None,
        download=False,
    ):
        """
        Args:
            root: Directory where the dataset files live (or are downloaded to).
            split: One of the names returned by ``get_available_splits()``.
            transform: Optional callable applied to each loaded image.
            target_transform: Optional callable applied to each label.
            download: If True, download the dataset when ``root`` is missing
                or empty; if False, ``root`` must already exist.

        Raises:
            ValueError: If ``root`` does not exist and ``download`` is False,
                or if ``split`` is not an available split.
        """
        self.root = root

        if download:
            # Only download when the target directory is missing or empty, so
            # re-instantiating with download=True does not re-fetch the data.
            if not os.path.isdir(self.root):
                os.makedirs(self.root, exist_ok=False)
                self.download_and_remove()
            elif not os.listdir(self.root):
                self.download_and_remove()
        else:
            # The given directory does not exist so the user should be aware of downloading it
            # Otherwise proceed as usual
            if not os.path.isdir(self.root):
                raise ValueError(
                    "The given path does not exist. "
                    "You should probably initialize the dataset with download=True."
                )

        self.transform = transform
        self.target_transform = target_transform

        # Validate the split before generating it, so subclasses can rely on
        # self.split being one of get_available_splits().
        if split not in self.get_available_splits():
            raise ValueError(
                f"Supported splits are: {', '.join(self.get_available_splits())}"
            )

        self.split = split

        self.generate_split()

    @abstractmethod
    def generate_split(self):
        # Implementations must populate self.paths (image file paths) and
        # self.labels (per-image labels) for the current self.split.
        raise NotImplementedError

    @abstractmethod
    def download_and_remove(self):
        # Implementations must download the dataset into self.root and remove
        # any temporary archive files afterwards.
        raise NotImplementedError

    def get_available_splits(self):
        """Return the split names accepted by ``__init__``."""
        return ["train", "test", "train+test"]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for the example at ``idx``."""
        img = Image.open(self.paths[idx])
        label = self.labels[idx]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return (img, label)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import os | ||
import zipfile | ||
|
||
from ..datasets.base_dataset import BaseDataset | ||
from ..utils.common_functions import _urlretrieve | ||
|
||
|
||
class Cars196(BaseDataset):
    """Stanford Cars (Cars196) dataset.

    Class ids 1-98 make up the training split and 99-196 the test split
    (per the ranges below); "train+test" keeps every class.
    """

    DOWNLOAD_URL = "https://www.kaggle.com/api/v1/datasets/download/jutrera/stanford-car-dataset-by-classes-folder"

    def generate_split(self):
        # Class ids 1-98 go to the train split, 99-196 to the test split;
        # any other split name keeps all classes.
        split_classes = {
            "train": set(range(1, 99)),
            "test": set(range(99, 197)),
        }
        keep = split_classes.get(self.split, set(range(1, 197)))

        with open(os.path.join(self.root, "names.csv"), "r") as f:
            names = [line.strip() for line in f.readlines()]

        train_paths, train_labels = self._load_csv(
            os.path.join(self.root, "anno_train.csv"), names, split="train"
        )
        test_paths, test_labels = self._load_csv(
            os.path.join(self.root, "anno_test.csv"), names, split="test"
        )

        self.paths, self.labels = [], []
        for path, label in zip(train_paths + test_paths, train_labels + test_labels):
            if label in keep:
                self.paths.append(path)
                self.labels.append(label)

    def _load_csv(self, path, names, split):
        """Parse an annotation CSV into parallel (paths, labels) lists.

        Each row is ``filename,...,label``; the on-disk image path is
        root/car_data/car_data/<split>/<class name>/<filename>, where the
        class name comes from ``names`` (1-based label index) with any "/"
        replaced so it is a valid directory name.
        """
        all_paths, all_labels = [], []
        with open(path, "r") as f:
            for row in f:
                fields = row.split(",")
                filename = fields[0]
                label = int(fields[-1])
                class_dir = names[label - 1].replace("/", "-")
                all_paths.append(
                    os.path.join(
                        self.root,
                        "car_data",
                        "car_data",
                        split,
                        class_dir,
                        filename,
                    )
                )
                all_labels.append(label)
        return all_paths, all_labels

    def download_and_remove(self):
        """Download the Kaggle archive into root, extract it, delete the zip."""
        os.makedirs(self.root, exist_ok=True)
        archive_path = os.path.join(
            self.root, Cars196.DOWNLOAD_URL.split("/")[-1]
        )
        _urlretrieve(url=Cars196.DOWNLOAD_URL, filename=archive_path)
        with zipfile.ZipFile(archive_path, "r") as zip_ref:
            zip_ref.extractall(self.root)
        os.remove(archive_path)
Oops, something went wrong.