# main.py
#
# Training entry point: builds myOwnDataset and a DataLoader from the paths and
# hyperparameters in ``config``, then trains the model returned by
# ``utils.get_model_instance_segmentation`` with SGD.
import torch
import config
from utils import (
get_model_instance_segmentation,
collate_fn,
get_transform,
myOwnDataset,
)
def _to_device(imgs, annotations, device):
    """Move one collated batch onto ``device``.

    Args:
        imgs: Iterable of image tensors for the batch.
        annotations: Iterable of per-image target dicts; every value is a
            tensor and is moved with ``.to(device)``.
        device: Target ``torch.device``.

    Returns:
        ``(imgs, annotations)`` as new lists whose tensors live on ``device``.
    """
    imgs = [img.to(device) for img in imgs]
    annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]
    return imgs, annotations


def _preview_first_batch(data_loader, device):
    """Sanity-check the data pipeline by printing the first batch's targets.

    The script previously iterated over the *entire* DataLoader here. It
    copied every batch to ``device`` and printed all annotations before
    training started, which is a full extra pass over the data. One batch is
    enough to confirm that loading, collation and the device transfer work.
    An empty loader prints nothing.
    """
    for imgs, annotations in data_loader:
        _, annotations = _to_device(imgs, annotations, device)
        print(annotations)
        break


def _train_one_epoch(model, optimizer, data_loader, device):
    """Run one optimisation pass over ``data_loader``, printing progress.

    In train mode the model is called as ``model(imgs, targets)`` and returns
    a dict of loss tensors. Their sum must be a scalar, since it is
    back-propagated with a bare ``backward()`` call.
    """
    model.train()
    num_batches = len(data_loader)
    for step, (imgs, annotations) in enumerate(data_loader, start=1):
        imgs, annotations = _to_device(imgs, annotations, device)
        loss_dict = model(imgs, annotations)
        losses = sum(loss_dict.values())
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        # .item() prints a plain number instead of the tensor repr, which
        # carries device/grad_fn noise.
        print(f"Iteration: {step}/{num_batches}, Loss: {losses.item()}")


def main():
    """Build data pipeline, model and optimizer from ``config``, then train."""
    print("Torch version:", torch.__version__)

    # Dataset paths come from config; config.train_coco is presumably a
    # COCO-style annotation file (going by its name) -- confirm in utils.
    my_dataset = myOwnDataset(
        root=config.train_data_dir,
        annotation=config.train_coco,
        transforms=get_transform(),
    )
    data_loader = torch.utils.data.DataLoader(
        my_dataset,
        batch_size=config.train_batch_size,
        shuffle=config.train_shuffle_dl,
        num_workers=config.num_workers_dl,
        collate_fn=collate_fn,
    )

    # Prefer the GPU when one is available.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    _preview_first_batch(data_loader, device)

    model = get_model_instance_segmentation(config.num_classes)
    model.to(device)

    # Hand the optimizer only the trainable parameters; any parameter with
    # requires_grad=False is left out.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(
        params, lr=config.lr, momentum=config.momentum, weight_decay=config.weight_decay
    )

    # 1-based so the final epoch reads "N/N", matching the iteration counter.
    for epoch in range(1, config.num_epochs + 1):
        print(f"Epoch: {epoch}/{config.num_epochs}")
        _train_one_epoch(model, optimizer, data_loader, device)


# The guard is required for multi-process data loading. With num_workers > 0
# and the "spawn" start method (the default on Windows and macOS), each
# DataLoader worker re-imports this module. Unguarded top-level code would
# re-run the whole script there and fail while the worker is bootstrapping.
if __name__ == "__main__":
    main()