-
Notifications
You must be signed in to change notification settings - Fork 1
/
datasets.py
146 lines (131 loc) · 5.96 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# code adapted from https://github.com/pluskid/fitting-random-labels
from torchvision import datasets, transforms
import torch
import numpy as np
import time
class _RandomLabelsMixin:
    """Shared label-corruption logic for the *RandomLabels dataset wrappers.

    List this class before the torchvision dataset class in the bases. The
    dataset supplies ``self.targets``, and the subclass ``__init__`` sets
    ``self.n_classes`` before calling ``corrupt_labels``.
    """

    def corrupt_labels(self, corrupt_prob, seed=None):
        """Randomly resample labels in place.

        Each label is independently replaced, with probability
        ``corrupt_prob``, by a label drawn uniformly from
        ``range(self.n_classes)``. A resampled label may happen to equal the
        original one. After the call ``self.targets`` is a list of Python
        ints.

        Params
        ------
        corrupt_prob: float
            The probability of a label being replaced with a random label.
        seed: int or None
            Default None. Seed for a private random generator. None draws
            fresh OS entropy, so each run gets a different corruption. The
            global NumPy RNG is never reseeded.
        """
        # torchvision keeps labels in ``targets`` for both train and test
        # splits; the old ``test_labels`` alias is deprecated on MNIST and
        # does not exist on CIFAR10/CIFAR100.
        labels = np.array(self.targets)
        # A private generator avoids clobbering any global seed the caller set.
        rng = np.random.default_rng(seed)
        # random() is in [0, 1), so ``<`` makes P(mask) exactly corrupt_prob.
        mask = rng.random(len(labels)) < corrupt_prob
        labels[mask] = rng.integers(self.n_classes, size=int(mask.sum()))
        self.targets = [int(x) for x in labels]


class MNISTRandomLabels(_RandomLabelsMixin, datasets.MNIST):
    """MNIST dataset, with support for randomly corrupted labels.

    Params
    ------
    corrupt_prob: float
        Default 0.0. The probability of a label being replaced with a
        random label.
    num_classes: int
        Default 10. The number of classes in the dataset.
    seed: int or None
        Default None. Seed for the label-corruption RNG. None means fresh
        randomness on every construction.
    **kwargs
        Forwarded to ``torchvision.datasets.MNIST``.
    """

    def __init__(self, corrupt_prob=0.0, num_classes=10, seed=None, **kwargs):
        super(MNISTRandomLabels, self).__init__(**kwargs)
        self.n_classes = num_classes
        if corrupt_prob > 0:
            self.corrupt_labels(corrupt_prob, seed=seed)


class CIFAR100RandomLabels(_RandomLabelsMixin, datasets.CIFAR100):
    """CIFAR-100 dataset, with support for randomly corrupted labels.

    Params
    ------
    corrupt_prob: float
        Default 0.0. The probability of a label being replaced with a
        random label.
    num_classes: int
        Default 100. The number of classes in the dataset.
    seed: int or None
        Default None. Seed for the label-corruption RNG. None means fresh
        randomness on every construction.
    **kwargs
        Forwarded to ``torchvision.datasets.CIFAR100``.
    """

    def __init__(self, corrupt_prob=0.0, num_classes=100, seed=None, **kwargs):
        super(CIFAR100RandomLabels, self).__init__(**kwargs)
        self.n_classes = num_classes
        if corrupt_prob > 0:
            self.corrupt_labels(corrupt_prob, seed=seed)


class CIFAR10RandomLabels(_RandomLabelsMixin, datasets.CIFAR10):
    """CIFAR-10 dataset, with support for randomly corrupted labels.

    Params
    ------
    corrupt_prob: float
        Default 0.0. The probability of a label being replaced with a
        random label.
    num_classes: int
        Default 10. The number of classes in the dataset.
    seed: int or None
        Default None. Seed for the label-corruption RNG. None means fresh
        randomness on every construction.
    **kwargs
        Forwarded to ``torchvision.datasets.CIFAR10``.
    """

    def __init__(self, corrupt_prob=0.0, num_classes=10, seed=None, **kwargs):
        super(CIFAR10RandomLabels, self).__init__(**kwargs)
        self.n_classes = num_classes
        if corrupt_prob > 0:
            self.corrupt_labels(corrupt_prob, seed=seed)
def get_data_loader(name, batch_size, num_samples=None, corrupt_prob=0,
                    root='./data', num_workers=11):
    """Build train and test DataLoaders for one of the supported datasets.

    The training split may have randomly corrupted labels and may be
    restricted to a random subset of ``num_samples`` examples. The test
    split is always the clean torchvision dataset.

    Params
    ------
    name: str
        The name of the dataset. Choices are 'cifar10', 'mnist' and
        'cifar100'.
    batch_size: int
        The size of the batch for both loaders.
    num_samples: int or None
        Default None. The number of training samples to use. None means the
        whole training set.
    corrupt_prob: float between 0 and 1
        Default 0. The probability of a training label being random.
    root: str
        Default './data'. Directory where datasets are stored/downloaded.
    num_workers: int
        Default 11. Worker processes for each DataLoader.

    Returns
    -------
    (train_loader, test_loader, train_dataset), where ``train_dataset`` is
    the (possibly subsampled) training set wrapped by ``train_loader``.

    Raises
    ------
    ValueError
        If ``name`` is not a supported dataset, or ``num_samples`` is outside
        ``[0, len(training set)]``.
    """
    # name -> (training class with label corruption, plain test class)
    dataset_classes = {
        'cifar10': (CIFAR10RandomLabels, datasets.CIFAR10),
        'mnist': (MNISTRandomLabels, datasets.MNIST),
        'cifar100': (CIFAR100RandomLabels, datasets.CIFAR100),
    }
    try:
        train_cls, test_cls = dataset_classes[name]
    except KeyError:
        raise ValueError(
            "unknown dataset {!r}; choose from {}".format(
                name, sorted(dataset_classes))) from None

    train_dataset = train_cls(root=root,
                              train=True,
                              transform=transforms.ToTensor(),
                              download=True, corrupt_prob=corrupt_prob)
    test_dataset = test_cls(root=root,
                            train=False, download=True,
                            transform=transforms.ToTensor())

    n_train = len(train_dataset)
    if num_samples is None:
        num_samples = n_train
    # random_split does not reject num_samples > n_train (the negative second
    # length just yields the full set), so validate explicitly.
    if not 0 <= num_samples <= n_train:
        raise ValueError(
            "num_samples must be between 0 and {}, got {}".format(
                n_train, num_samples))

    # In case we want to train on part of the training set instead of all;
    # the remainder of the split is discarded.
    my_train_dataset, _ = torch.utils.data.random_split(
        dataset=train_dataset, lengths=[num_samples, n_train - num_samples])

    my_train_loader = torch.utils.data.DataLoader(dataset=my_train_dataset,
                                                  num_workers=num_workers,
                                                  batch_size=batch_size,
                                                  shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              num_workers=num_workers,
                                              batch_size=batch_size,
                                              shuffle=True)
    return my_train_loader, test_loader, my_train_dataset