dataset_preparation.py
"""
Prepares two kinds of datasets:
- An unlabelled dataset for self-supervised training
- Labelled datasets for classifier training
The functions in this file are:
- load_datasets: Load two dataset, one large unlabelled one and one with known tidal features
- find_mad_std: Find the median absolute deviation of each channel to use for normalisationg and augmentations
- normalisation: Normalises images
- split_datasets: Split datasets into training, validation, and testing for classifier training
- label_positive and label_negative: Assigns labels to images
- prepare_datasets: Uses the above functions to assemble the datasets
"""
import os

import numpy as np
import tensorflow as tf
from astropy.stats import mad_std

def load_datasets(dataset_path, splits):
    """
    Load two datasets, one large unlabelled one and one with known tidal features

    Parameters
    ----------
    dataset_path: str
        Path location of datasets
    splits: List of str
        Names of datasets to load from dataset_path
        Format: ['Unlabelled_ds', 'Tidal_ds']

    Returns
    -------
    Unlabelled_dset: Tensorflow Dataset
        Large unlabelled dataset
    Tidal_dset: Tensorflow Dataset
        Dataset with known tidal features
    """
    # Load the large (40,000 galaxies) unlabelled dataset
    path = os.path.join(dataset_path, splits[0])
    Unlabelled_dset = tf.data.Dataset.load(path, compression=None)
    # Load the small (380 galaxies) tidal feature dataset
    path = os.path.join(dataset_path, splits[1])
    Tidal_dset = tf.data.Dataset.load(path, compression=None)
    return Unlabelled_dset, Tidal_dset
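
# The loader assumes each split was written with tf.data.Dataset.save and that
# every element is a dict with an 'image' key, as used throughout this module.
# A minimal sketch of building a compatible split from an in-memory array
# (the array and path here are hypothetical, for illustration only):
#
#   images = np.random.rand(16, 64, 64, 3).astype(np.float32)  # N x H x W x bands
#   ds = tf.data.Dataset.from_tensor_slices({'image': images})
#   ds.save(os.path.join('datasets', 'Unlabelled_ds'), compression=None)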

def find_mad_std(dataset, bands):
    """
    Find the median absolute deviation of each channel to use for normalisation and augmentations

    Parameters
    ----------
    dataset: Tensorflow Dataset
        Dataset to be sampled from
    bands: list of str
        List of channels of the image e.g. ['g', 'r', 'i']

    Returns
    -------
    scaling: List of floats of length len(bands)
        Scaling factor by band/channel
    """
    cutouts = []
    # Collect the images of the first 1000 galaxies in the dataset
    for entry in dataset.take(1000):
        # Only select the image
        cutouts.append(entry['image'])
    cutouts = np.stack(cutouts)
    # Record the median absolute deviation of each band
    scaling = []
    for i in range(len(bands)):
        sigma = mad_std(cutouts[..., i].flatten())
        scaling.append(sigma)
    return scaling
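
# mad_std also accepts an axis argument, so the per-band loop above could be
# collapsed into a single call. A sketch of the equivalent computation,
# assuming cutouts has shape (N, H, W, len(bands)):
#
#   flat = cutouts.reshape(-1, cutouts.shape[-1])  # one column per band
#   scaling = list(mad_std(flat, axis=0))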

def normalisation(example, scale):
    """
    Normalises images with an asinh stretch

    Parameters
    ----------
    example: Dataset item
        Element in dataset
    scale: List of floats of length len(bands)
        Scaling factor (median absolute deviation) by band/channel

    Returns
    -------
    img: Tensor
        Normalised image
    """
    # Get only the image from the dataset element
    img = example['image']
    img = tf.math.asinh(img / tf.constant(scale, dtype=tf.float32) / 3.)
    # Return the normalised image
    return img
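
# The asinh stretch is roughly linear for pixel values well below ~3 sigma and
# logarithmic above, compressing bright galaxy cores while preserving faint
# low-surface-brightness structure. A quick numeric check (illustrative only):
#
#   x = tf.constant([0.1, 1.0, 10.0, 100.0])
#   tf.math.asinh(x / 3.)  # ~[0.033, 0.327, 1.919, 4.200]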

def split_datasets(Unlabelled_dset, Tidal_dset, data_sizes):
    """
    Split datasets into training, validation, and testing for classifier training

    Parameters
    ----------
    Unlabelled_dset: Tensorflow Dataset
        Large unlabelled dataset
    Tidal_dset: Tensorflow Dataset
        Dataset with known tidal features
    data_sizes: List of ints of length 3
        Dataset sizes per class to use for training, validation, and testing
        Format: [train_size, val_size, test_size]
        For class = 2 (tidal, no_tidal), train_size should be half the total
        training set size.

    Returns
    -------
    No_tidal_dsets: List of Tensorflow Datasets
        Negative examples to use for training, validation, and testing
        Format: [train_dset, val_dset, test_dset]
    Tidal_dsets: List of Tensorflow Datasets
        Positive examples to use for training, validation, and testing
        Format: [train_dset, val_dset, test_dset]
    """
    # Unpack data_sizes
    train_size, val_size, test_size = data_sizes
    # Split the dataset without tidal features
    no_tidal_train_dset = Unlabelled_dset.take(train_size)
    no_tidal_val_dset = Unlabelled_dset.skip(train_size).take(val_size)
    no_tidal_test_dset = Unlabelled_dset.skip(train_size + val_size).take(test_size)
    # Split the dataset with tidal features
    tidal_train_dset = Tidal_dset.take(train_size)
    tidal_val_dset = Tidal_dset.skip(train_size).take(val_size)
    tidal_test_dset = Tidal_dset.skip(train_size + val_size).take(test_size)
    No_tidal_dsets = [no_tidal_train_dset, no_tidal_val_dset, no_tidal_test_dset]
    Tidal_dsets = [tidal_train_dset, tidal_val_dset, tidal_test_dset]
    return No_tidal_dsets, Tidal_dsets
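
# take/skip split the datasets sequentially: with data_sizes = [200, 50, 50],
# elements 0-199 of each class form the training split, 200-249 validation,
# and 250-299 testing. The source datasets are therefore assumed to be in no
# meaningful order (or already shuffled on disk). Hypothetical usage:
#
#   no_tidal_splits, tidal_splits = split_datasets(unlabelled, tidal, [200, 50, 50])
#   train_neg, val_neg, test_neg = no_tidal_splits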

# Functions to assign positive (tidal) and negative (non-tidal) labels
def label_positive(image):
    return image, 1


def label_negative(image):
    return image, 0
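
# Applied through Dataset.map, these turn an image-only dataset into
# (image, label) pairs, e.g. tidal_dset.map(label_positive) yields elements
# of the form (image, 1).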

def prepare_datasets(shuffle_buffer, bands, labelled_batch_size,
                     unlabelled_batch_size, dataset_path, splits, data_sizes
                     ):
    """
    Prepares datasets for training, validation, and testing using the functions above

    Parameters
    ----------
    shuffle_buffer: int
        Buffer size to use when shuffling datasets
    bands: list of str
        List of channels of the image e.g. ['g', 'r', 'i']
    labelled_batch_size: int
        Batch size to use when training the supervised classifier
    unlabelled_batch_size: int
        Batch size to use when training the self-supervised encoder
    dataset_path: str
        Path location of datasets
    splits: List of str
        Names of datasets to load from dataset_path
        Format: ['No_tidal', 'Tidal_ds_no_dup']
    data_sizes: List of ints of length 3
        Dataset sizes per class to use for training, validation, and testing
        Format: [train_size, val_size, test_size]
        For class = 2 (tidal, no_tidal), train_size should be half the total
        training set size.

    Returns
    -------
    unlabelled_train_dataset: Batched Tensorflow Dataset
        Dataset to use for training the self-supervised encoder
    labelled_train_dataset: Batched Tensorflow Dataset
        Dataset to use for training the supervised classifier
    val_dataset: Batched Tensorflow Dataset
        Dataset to use for validation of the supervised classifier
    test_dataset: Batched Tensorflow Dataset
        Dataset to use for testing of the supervised classifier
    scale: List of floats of length len(bands)
        Scaling factor by band/channel
    """
    train_size, val_size, test_size = data_sizes
    # Load datasets and find the scaling factor
    Unlabelled_dset, Tidal_dset = load_datasets(dataset_path, splits)
    scale = find_mad_std(Unlabelled_dset, bands=bands)
    # Normalise datasets
    norm_unlabelled_dset = Unlabelled_dset.map(lambda x: normalisation(x, scale))
    norm_Tidal_dset = Tidal_dset.map(lambda x: normalisation(x, scale))
    # Split datasets
    No_tidal_dsets, Tidal_dsets = split_datasets(norm_unlabelled_dset, norm_Tidal_dset, data_sizes)
    no_tidal_train_dset, no_tidal_val_dset, no_tidal_test_dset = No_tidal_dsets
    tidal_train_dset, tidal_val_dset, tidal_test_dset = Tidal_dsets
    # Assign labels to datasets
    positive_dset_train = tidal_train_dset.map(label_positive)
    negative_dset_train = no_tidal_train_dset.map(label_negative)
    positive_dset_val = tidal_val_dset.map(label_positive)
    negative_dset_val = no_tidal_val_dset.map(label_negative)
    positive_dset_test = tidal_test_dset.map(label_positive)
    negative_dset_test = no_tidal_test_dset.map(label_negative)
    # Combine positive and negative datasets
    labelled_dset_train = positive_dset_train.concatenate(negative_dset_train)
    labelled_dset_val = positive_dset_val.concatenate(negative_dset_val)
    labelled_dset_test = positive_dset_test.concatenate(negative_dset_test)
    # Shuffle and batch datasets; the validation set is shuffled before
    # batching, since shuffling a dataset already batched into a single
    # batch would have no effect
    labelled_train_dataset = (labelled_dset_train
                              .shuffle(buffer_size=shuffle_buffer)
                              .batch(labelled_batch_size, drop_remainder=True)
                              )
    val_dataset = (labelled_dset_val
                   .shuffle(buffer_size=shuffle_buffer)
                   .batch(val_size * 2)
                   )
    test_dataset = (labelled_dset_test
                    .batch(test_size * 2)
                    )
    unlabelled_train_dataset = (norm_Tidal_dset
                                .concatenate(norm_unlabelled_dset)
                                .shuffle(buffer_size=shuffle_buffer)
                                .batch(unlabelled_batch_size, drop_remainder=True)
                                )
    return unlabelled_train_dataset, labelled_train_dataset, val_dataset, test_dataset, scale
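
# A hypothetical end-to-end invocation; all paths, names, and sizes below are
# illustrative placeholders, not values taken from this repository:
if __name__ == "__main__":
    (unlabelled_train, labelled_train,
     val_ds, test_ds, scale) = prepare_datasets(
        shuffle_buffer=5000,
        bands=['g', 'r', 'i'],
        labelled_batch_size=32,
        unlabelled_batch_size=256,
        dataset_path='datasets',
        splits=['Unlabelled_ds', 'Tidal_ds'],
        data_sizes=[200, 50, 50],
    )
    print('Per-band scaling:', scale)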