-
Notifications
You must be signed in to change notification settings - Fork 0
/
training.py
63 lines (52 loc) · 1.64 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from typing import Callable, Tuple, Union

import torch
from torch_geometric.loader import DataLoader
def reset_metrics(metrics: dict):
    """Call ``reset()`` on every metric object held in *metrics*.

    ``train`` and ``test`` invoke this at the start of each pass so that
    state accumulated during a previous pass is discarded.

    Args:
        metrics: mapping of metric name -> object exposing a ``reset()``
            method. Only the values are used; the keys are ignored.
    """
    for tracked in metrics.values():
        tracked.reset()
def train(
    model: torch.nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_fn: Callable[..., torch.Tensor],
    metrics: dict,
    device: Union[torch.device, str],
) -> float:
    """Run one optimisation pass over *loader* and return the aggregated loss.

    Args:
        model: network called as ``model(data)``; it must return a pair
            ``(out, aux_loss)`` where ``aux_loss`` is either ``None`` or a
            tensor that is added to the main loss.
        loader: iterable of batches; each batch must support ``.to(device)``
            and expose targets as ``.y``.
        optimizer: optimizer stepping the model parameters.
        loss_fn: called as ``loss_fn(out, data.y)`` and must return a
            tensor suitable for ``.backward()``.
        metrics: mapping that must contain a ``"loss"`` entry with
            ``update()``/``compute()``/``reset()``. Every entry is reset
            before the pass starts.
        device: target passed to ``data.to(...)`` (e.g. ``torch.device``
            or a string such as ``"cuda"``).

    Returns:
        ``float(metrics["loss"].compute())`` after all batches were fed.
        NOTE(review): if ``metrics["loss"]`` is a plain running mean, this
        is the unweighted mean of per-batch losses; confirm that is intended.
    """
    model.train()
    reset_metrics(metrics=metrics)
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        # The model returns an auxiliary loss alongside its prediction;
        # it is optional and only added when present.
        out, aux_loss = model(data)
        loss = loss_fn(out, data.y)
        if aux_loss is not None:
            loss = loss + aux_loss
        loss.backward()
        optimizer.step()
        # Detach before logging so the metric does not keep the autograd
        # graph of this batch alive.
        metrics["loss"].update(loss.detach())
    return float(metrics["loss"].compute())
@torch.no_grad()
def test(
    model: torch.nn.Module,
    loader: DataLoader,
    loss_fn: Callable[..., torch.Tensor],
    metrics: dict,
    device: Union[torch.device, str],
) -> Tuple[float, float]:
    """Evaluate *model* on *loader* with gradient tracking disabled.

    Args:
        model: network called as ``model(data)``; it must return a pair
            ``(out, aux_loss)`` where ``aux_loss`` is either ``None`` or a
            tensor that is added to the main loss.
        loader: iterable of batches; each batch must support ``.to(device)``
            and expose targets as ``.y``.
        loss_fn: called as ``loss_fn(out, data.y)``.
        metrics: mapping that must contain both an ``"acc"`` entry (updated
            with ``(out, data.y)``) and a ``"loss"`` entry (updated with the
            batch loss). Every entry is reset before evaluation starts.
        device: target passed to ``data.to(...)`` (e.g. ``torch.device``
            or a string such as ``"cuda"``).

    Returns:
        ``(accuracy, loss)`` as Python floats, taken from
        ``metrics["acc"].compute()`` and ``metrics["loss"].compute()``.
    """
    model.eval()
    reset_metrics(metrics=metrics)
    for data in loader:
        data = data.to(device)
        out, aux_loss = model(data)
        # Include the optional auxiliary loss so the reported loss is
        # computed the same way as during training.
        loss = loss_fn(out, data.y)
        if aux_loss is not None:
            loss = loss + aux_loss
        # The accuracy metric is always updated, so callers must supply an
        # "acc" entry even when the task is not classification.
        metrics["acc"].update(out, data.y)
        metrics["loss"].update(loss)
    acc = metrics["acc"].compute()
    loss = metrics["loss"].compute()
    return float(acc), float(loss)