Merge pull request #724 from ir2718/dataset
[`feat`] Addition of popular image retrieval benchmark datasets
Kevin Musgrave authored Dec 11, 2024
2 parents b403b8e + 539c93a commit c82c626
Showing 14 changed files with 784 additions and 0 deletions.
180 changes: 180 additions & 0 deletions docs/datasets.md
@@ -0,0 +1,180 @@
# Datasets

Dataset classes give you a way to automatically download a dataset and transform it into a PyTorch dataset.

All implemented datasets have class-disjoint train-test splits, making them suitable for benchmarking on image retrieval and one-shot/few-shot classification tasks.

## BaseDataset

All dataset classes extend this class and therefore inherit its ```__init__``` parameters.

```python
datasets.base_dataset.BaseDataset(
    root,
    split="train+test",
    transform=None,
    target_transform=None,
    download=False
)
```

**Parameters**:

* **root**: The path where the dataset files are saved.
* **split**: A string that determines which split of the dataset is loaded.
* **transform**: A `torchvision.transforms` object which will be used on the input images.
* **target_transform**: A `torchvision.transforms` object which will be used on the labels.
* **download**: Whether to download the dataset. Setting this to False when the dataset is not already on disk will raise a ValueError.

**Required Implementations**:
```python
@abstractmethod
def download_and_remove(self):
    raise NotImplementedError

@abstractmethod
def generate_split(self):
    raise NotImplementedError
```

## CUB-200-2011

```python
datasets.cub.CUB(*args, **kwargs)
```

**Defined splits**:

- `train` - Consists of 5864 examples, taken from classes 1 to 100.
- `test` - Consists of 5924 examples, taken from classes 101 to 200.
- `train+test` - Consists of 11788 examples, taken from all classes.

**Loading different dataset splits**
```python
train_dataset = CUB(
    root="data",
    split="train",
    transform=None,
    target_transform=None,
    download=True,
)
# No need to download the dataset again once it is already on disk
test_dataset = CUB(
    root="data",
    split="test",
    transform=None,
    target_transform=None,
    download=False,
)
train_and_test_dataset = CUB(
    root="data",
    split="train+test",
    transform=None,
    target_transform=None,
    download=False,
)
```
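
In practice you will usually pass a real transform rather than `None`. The sketch below wraps the train split in a `DataLoader` with a typical torchvision pipeline; the crop size, batch size, and normalization statistics are illustrative choices, not values prescribed by this library.

```python
from torch.utils.data import DataLoader
from torchvision import transforms

from pytorch_metric_learning.datasets.cub import CUB

# An ImageNet-style preprocessing pipeline; sizes and statistics are illustrative
train_transform = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB")),  # guards against occasional grayscale images
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = CUB(
    root="data",
    split="train",
    transform=train_transform,
    download=False,  # assumes the dataset was already downloaded, e.g. by the example above
)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
images, labels = next(iter(train_loader))  # images has shape (32, 3, 224, 224)
```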

## Cars196

```python
datasets.cars196.Cars196(*args, **kwargs)
```

**Defined splits**:

- `train` - Consists of 8054 examples, taken from classes 1 to 98.
- `test` - Consists of 8131 examples, taken from classes 99 to 196.
- `train+test` - Consists of 16185 examples, taken from all classes.

**Loading different dataset splits**
```python
train_dataset = Cars196(
    root="data",
    split="train",
    transform=None,
    target_transform=None,
    download=True,
)
# No need to download the dataset again once it is already on disk
test_dataset = Cars196(
    root="data",
    split="test",
    transform=None,
    target_transform=None,
    download=False,
)
train_and_test_dataset = Cars196(
    root="data",
    split="train+test",
    transform=None,
    target_transform=None,
    download=False,
)
```
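
Because benchmarking relies on the splits being class-disjoint, this property can be checked directly through the `labels` attribute that `BaseDataset` populates. A quick sanity check, assuming the dataset was downloaded by the example above:

```python
from pytorch_metric_learning.datasets.cars196 import Cars196

train_dataset = Cars196(root="data", split="train", download=False)
test_dataset = Cars196(root="data", split="test", download=False)

# Train and test classes must not overlap, and together the two splits
# cover all 16185 examples
assert set(train_dataset.labels).isdisjoint(test_dataset.labels)
assert len(train_dataset) + len(test_dataset) == 16185
```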

## INaturalist2018

```python
datasets.inaturalist2018.INaturalist2018(*args, **kwargs)
```

**Defined splits**:

- `train` - Consists of 325846 examples.
- `test` - Consists of 136093 examples.
- `train+test` - Consists of 461939 examples.

**Loading different dataset splits**
```python
# The download takes a while - the dataset is very large
train_dataset = INaturalist2018(
    root="data",
    split="train",
    transform=None,
    target_transform=None,
    download=True,
)
# No need to download the dataset again once it is already on disk
test_dataset = INaturalist2018(
    root="data",
    split="test",
    transform=None,
    target_transform=None,
    download=False,
)
train_and_test_dataset = INaturalist2018(
    root="data",
    split="train+test",
    transform=None,
    target_transform=None,
    download=False,
)
```
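
If downstream code expects tensor labels rather than Python integers, `target_transform` is a natural place for the conversion. A minimal sketch; the lambda is just one illustrative choice:

```python
import torch

from pytorch_metric_learning.datasets.inaturalist2018 import INaturalist2018

train_dataset = INaturalist2018(
    root="data",
    split="train",
    target_transform=lambda y: torch.tensor(y, dtype=torch.long),
    download=False,  # assumes the dataset was already downloaded above
)
_, label = train_dataset[0]  # label is now a scalar torch.long tensor
```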

## StanfordOnlineProducts

```python
datasets.sop.StanfordOnlineProducts(*args, **kwargs)
```

**Defined splits**:

- `train` - Consists of 59551 examples.
- `test` - Consists of 60502 examples.
- `train+test` - Consists of 120053 examples.

**Loading different dataset splits**
```python
# The download takes a while - the dataset is very large
train_dataset = StanfordOnlineProducts(
    root="data",
    split="train",
    transform=None,
    target_transform=None,
    download=True,
)
# No need to download the dataset again once it is already on disk
test_dataset = StanfordOnlineProducts(
    root="data",
    split="test",
    transform=None,
    target_transform=None,
    download=False,
)
train_and_test_dataset = StanfordOnlineProducts(
    root="data",
    split="train+test",
    transform=None,
    target_transform=None,
    download=False,
)
```
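
Metric-learning losses typically need batches with several samples per class, which a plain shuffled `DataLoader` does not guarantee. The sketch below pairs this dataset with the library's `MPerClassSampler`; `m=4`, the batch size, and the resize shape are illustrative values:

```python
from torch.utils.data import DataLoader
from torchvision import transforms

from pytorch_metric_learning import samplers
from pytorch_metric_learning.datasets.sop import StanfordOnlineProducts

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
train_dataset = StanfordOnlineProducts(
    root="data",
    split="train",
    transform=transform,
    download=False,  # assumes the dataset was already downloaded above
)

# Yields batches containing m samples per class
sampler = samplers.MPerClassSampler(
    train_dataset.labels, m=4, batch_size=32, length_before_new_iter=len(train_dataset)
)
train_loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)
```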
34 changes: 34 additions & 0 deletions docs/extend/datasets.md
@@ -0,0 +1,34 @@
# How to write custom datasets

1. Subclass the ```datasets.base_dataset.BaseDataset``` class
2. Add implementations for abstract methods from the base class:
- ```download_and_remove()```
- ```generate_split()```


```python
from pytorch_metric_learning.datasets.base_dataset import BaseDataset


class MyDataset(BaseDataset):

    def __init__(self, my_parameter, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.my_parameter = my_parameter

    def download_and_remove(self):
        # Downloads the dataset files needed
        #
        # If you're using a dataset that you've already downloaded elsewhere,
        # just use an empty implementation
        pass

    def generate_split(self):
        # Creates a list of image paths, and saves them into self.paths
        # Creates a list of labels for the images, and saves them into self.labels
        #
        # The default splits that need to be covered are `train`, `test`, and `train+test`
        # If you need a different split setup, override `get_available_splits(self)` to return
        # the split names you want
        pass
```
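
Once both methods are implemented, the custom class is constructed just like the built-in datasets. A sketch with placeholder values:

```python
# Assuming the files already live under data/my_dataset, download=False is fine
# because download_and_remove() above is a no-op
dataset = MyDataset(
    my_parameter=42,         # placeholder for your own argument
    root="data/my_dataset",  # hypothetical path
    split="train",
    download=False,
)
img, label = dataset[0]  # works once generate_split() fills self.paths / self.labels
```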
2 changes: 2 additions & 0 deletions mkdocs.yml
@@ -1,6 +1,7 @@
site_name: PyTorch Metric Learning
nav:
- Home: index.md
- Datasets: datasets.md
- Distances: distances.md
- Losses: losses.md
- Miners: miners.md
@@ -16,6 +17,7 @@ nav:
- Common Functions: common_functions.md
- Distributed: distributed.md
- How to extend this library:
- Custom datasets: extend/datasets.md
- Custom losses: extend/losses.md
- Custom miners: extend/miners.md
- Frequently Asked Questions: faq.md
71 changes: 71 additions & 0 deletions src/pytorch_metric_learning/datasets/base_dataset.py
@@ -0,0 +1,71 @@
import os
from abc import ABC, abstractmethod

from PIL import Image
from torch.utils.data import Dataset


class BaseDataset(ABC, Dataset):

    def __init__(
        self,
        root,
        split="train+test",
        transform=None,
        target_transform=None,
        download=False,
    ):
        self.root = root

        if download:
            # Download only if the root directory is missing or empty;
            # a non-empty directory is assumed to already contain the dataset
            if not os.path.isdir(self.root):
                os.makedirs(self.root, exist_ok=False)
                self.download_and_remove()
            elif not os.listdir(self.root):
                self.download_and_remove()
        else:
            # If the given directory does not exist, make the user aware that
            # the dataset needs to be downloaded
            # Otherwise proceed as usual
            if not os.path.isdir(self.root):
                raise ValueError(
                    "The given path does not exist. "
                    "You should probably initialize the dataset with download=True."
                )

        self.transform = transform
        self.target_transform = target_transform

        if split not in self.get_available_splits():
            raise ValueError(
                f"Supported splits are: {', '.join(self.get_available_splits())}"
            )

        self.split = split

        self.generate_split()

    @abstractmethod
    def generate_split(self):
        raise NotImplementedError

    @abstractmethod
    def download_and_remove(self):
        raise NotImplementedError

    def get_available_splits(self):
        return ["train", "test", "train+test"]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx])
        label = self.labels[idx]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return (img, label)
67 changes: 67 additions & 0 deletions src/pytorch_metric_learning/datasets/cars196.py
@@ -0,0 +1,67 @@
import os
import zipfile

from ..datasets.base_dataset import BaseDataset
from ..utils.common_functions import _urlretrieve


class Cars196(BaseDataset):

    DOWNLOAD_URL = "https://www.kaggle.com/api/v1/datasets/download/jutrera/stanford-car-dataset-by-classes-folder"

    def generate_split(self):
        # Training set is the first 98 classes, test set is the remaining 98 classes
        if self.split == "train":
            classes = set(range(1, 99))
        elif self.split == "test":
            classes = set(range(99, 197))
        else:
            classes = set(range(1, 197))

        with open(os.path.join(self.root, "names.csv"), "r") as f:
            names = [x.strip() for x in f.readlines()]

        paths_train, labels_train = self._load_csv(
            os.path.join(self.root, "anno_train.csv"), names, split="train"
        )
        paths_test, labels_test = self._load_csv(
            os.path.join(self.root, "anno_test.csv"), names, split="test"
        )
        paths = paths_train + paths_test
        labels = labels_train + labels_test

        self.paths, self.labels = [], []
        for path, label in zip(paths, labels):
            if label in classes:
                self.paths.append(path)
                self.labels.append(label)

    def _load_csv(self, path, names, split):
        all_paths, all_labels = [], []
        with open(path, "r") as f:
            for line in f:
                # Only the first column (file name) and the last column
                # (class label) of each annotation row are used
                path_annos = line.split(",")
                curr_path = path_annos[0]
                curr_label = path_annos[-1]
                all_paths.append(
                    os.path.join(
                        self.root,
                        "car_data",
                        "car_data",
                        split,
                        names[int(curr_label) - 1].replace("/", "-"),
                        curr_path,
                    )
                )
                all_labels.append(int(curr_label))
        return all_paths, all_labels

    def download_and_remove(self):
        os.makedirs(self.root, exist_ok=True)
        download_folder_path = os.path.join(
            self.root, Cars196.DOWNLOAD_URL.split("/")[-1]
        )
        _urlretrieve(url=Cars196.DOWNLOAD_URL, filename=download_folder_path)
        with zipfile.ZipFile(download_folder_path, "r") as zip_ref:
            zip_ref.extractall(self.root)
        os.remove(download_folder_path)