forked from ashleve/lightning-hydra-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mnist_datamodule.py
130 lines (107 loc) · 4.42 KB
/
mnist_datamodule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from typing import Any, Dict, Optional, Tuple
import torch
from lightning import LightningDataModule
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
class MNISTDataModule(LightningDataModule):
"""Example of LightningDataModule for MNIST dataset.
A DataModule implements 6 key methods:
def prepare_data(self):
# things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
# download data, pre-process, split, save to disk, etc...
def setup(self, stage):
# things to do on every process in DDP
# load data, set variables, etc...
def train_dataloader(self):
# return train dataloader
def val_dataloader(self):
# return validation dataloader
def test_dataloader(self):
# return test dataloader
def teardown(self):
# called on every process in DDP
# clean up after fit or test
This allows you to share a full dataset without explaining how to download,
split, transform and process the data.
Read the docs:
https://lightning.ai/docs/pytorch/latest/data/datamodule.html
"""
def __init__(
self,
data_dir: str = "data/",
train_val_test_split: Tuple[int, int, int] = (55_000, 5_000, 10_000),
batch_size: int = 64,
num_workers: int = 0,
pin_memory: bool = False,
):
super().__init__()
# this line allows to access init params with 'self.hparams' attribute
# also ensures init params will be stored in ckpt
self.save_hyperparameters(logger=False)
# data transformations
self.transforms = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
self.data_train: Optional[Dataset] = None
self.data_val: Optional[Dataset] = None
self.data_test: Optional[Dataset] = None
@property
def num_classes(self):
return 10
def prepare_data(self):
"""Download data if needed.
Do not use it to assign state (self.x = y).
"""
MNIST(self.hparams.data_dir, train=True, download=True)
MNIST(self.hparams.data_dir, train=False, download=True)
def setup(self, stage: Optional[str] = None):
"""Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
careful not to execute things like random split twice!
"""
# load and split datasets only if not loaded already
if not self.data_train and not self.data_val and not self.data_test:
trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms)
testset = MNIST(self.hparams.data_dir, train=False, transform=self.transforms)
dataset = ConcatDataset(datasets=[trainset, testset])
self.data_train, self.data_val, self.data_test = random_split(
dataset=dataset,
lengths=self.hparams.train_val_test_split,
generator=torch.Generator().manual_seed(42),
)
def train_dataloader(self):
return DataLoader(
dataset=self.data_train,
batch_size=self.hparams.batch_size,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=True,
)
def val_dataloader(self):
return DataLoader(
dataset=self.data_val,
batch_size=self.hparams.batch_size,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=False,
)
def test_dataloader(self):
return DataLoader(
dataset=self.data_test,
batch_size=self.hparams.batch_size,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=False,
)
def teardown(self, stage: Optional[str] = None):
"""Clean up after fit or test."""
pass
def state_dict(self):
"""Extra things to save to checkpoint."""
return {}
def load_state_dict(self, state_dict: Dict[str, Any]):
"""Things to do when loading checkpoint."""
pass
if __name__ == "__main__":
_ = MNISTDataModule()