train.py (forked from arimousa/DDAD)
import torch
import os
import torch.nn as nn
from forward_process import *
from dataset import *
from test import *
from loss import *
from sample import *
def trainer(model, category, config):
    '''
    Train the UNet denoising model on a single dataset category.
    :param model: the UNet model
    :param category: the category of the dataset
    :param config: the experiment configuration
    '''
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config.model.learning_rate, weight_decay=config.model.weight_decay
    )
    train_dataset = Dataset_maker(
        root=config.data.data_dir,
        category=category,
        config=config,
        is_train=True,
    )
    trainloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.batch_size,
        shuffle=True,
        num_workers=config.model.num_workers,
        drop_last=True,
    )
    # Make sure the checkpoint directories exist before training starts.
    if not os.path.exists('checkpoints'):
        os.mkdir('checkpoints')
    if not os.path.exists(config.model.checkpoint_dir):
        os.mkdir(config.model.checkpoint_dir)
    for epoch in range(config.model.epochs):
        for step, batch in enumerate(trainloader):
            # Sample a random diffusion timestep for each image in the batch.
            t = torch.randint(0, config.model.trajectory_steps, (batch[0].shape[0],), device=config.model.device).long()
            optimizer.zero_grad()
            loss = get_loss(model, batch[0], t, config)
            loss.backward()
            optimizer.step()
            # Log the loss of the first batch every 50 epochs.
            if epoch % 50 == 0 and step == 0:
                print(f"Epoch {epoch} | Loss: {loss.item()}")
            # Save an intermediate checkpoint every 250 epochs.
            if epoch % 250 == 0 and step == 0:
                if config.model.save_model:
                    model_save_dir = os.path.join(os.getcwd(), config.model.checkpoint_dir, category)
                    if not os.path.exists(model_save_dir):
                        os.mkdir(model_save_dir)
                    torch.save(model.state_dict(), os.path.join(model_save_dir, str(epoch)))
    # Save the final weights once all epochs have finished.
    if config.model.save_model:
        model_save_dir = os.path.join(os.getcwd(), config.model.checkpoint_dir, category)
        if not os.path.exists(model_save_dir):
            os.mkdir(model_save_dir)
        torch.save(model.state_dict(), os.path.join(model_save_dir, str(config.model.epochs)))
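
Example usage (not part of train.py): a minimal sketch of how trainer() might be invoked. It assumes the config is an OmegaConf object with attribute access (config.model.learning_rate, config.data.batch_size, etc., as read by trainer above); the concrete values, the YAML-like layout, and the UNetModel constructor are illustrative placeholders, not definitions taken from this repository.

import torch
from omegaconf import OmegaConf

# Illustrative config: only the keys actually read by trainer() are shown,
# and the values are made-up defaults, not the repository's settings.
config = OmegaConf.create({
    "data": {"data_dir": "datasets/MVTec", "batch_size": 16},
    "model": {
        "learning_rate": 1e-4,
        "weight_decay": 0.0,
        "num_workers": 4,
        "epochs": 1000,
        "trajectory_steps": 1000,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "checkpoint_dir": "checkpoints/DDAD",
        "save_model": True,
    },
})

# Hypothetical model construction: replace with the UNet defined in this repository.
model = UNetModel()  # placeholder constructor name
model.to(config.model.device)

trainer(model, category="hazelnut", config=config)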