-
Notifications
You must be signed in to change notification settings - Fork 19
/
train.py
79 lines (67 loc) · 2.9 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import torch
import torch.nn as nn
import utils.bio as bio
from tqdm import tqdm
from extract import extract
from utils import utils
from test import do_eval
def train(args,
          epoch,
          model,
          trn_loader,
          dev_loaders,
          summarizer,
          optimizer,
          scheduler):
    """Run one training epoch with periodic dev evaluation and checkpointing.

    At most ``int(args.total_steps / args.epochs)`` batches are drawn from
    ``trn_loader``; each one gets a forward pass, backward pass, gradient
    clipping (max norm 1.0), an optimizer step and a scheduler step.
    Every 1000 steps (step 0 excluded) the model is evaluated on all dev sets
    and a checkpoint is saved (see ``_evaluate_on_dev``). Every
    ``args.summary_step`` steps (step 0 excluded) the running training losses
    are printed.

    Args:
        args: run configuration. This function reads ``device``,
            ``total_steps``, ``epochs``, ``summary_step``, ``save_path``,
            ``dev_data_path`` and ``dev_gold_path``.
        epoch: current epoch index; used only in log labels and output paths.
        model: module called with keyword tensors and returning
            ``(batch_loss, pred_loss, arg_loss)``.
        trn_loader: iterable of training batches. Each batch unpacks into five
            tensors: token ids, attention mask, single predicate labels,
            single argument labels and all predicate labels.
        dev_loaders: one loader per dev set, zipped in order with
            ``args.dev_data_path`` and ``args.dev_gold_path``.
        summarizer: object whose ``save_results`` receives one row per
            interim evaluation.
        optimizer: optimizer stepped once per trained batch.
        scheduler: learning-rate scheduler stepped once per trained batch.

    Returns:
        ``[mean predicate loss, mean argument loss]`` over the batches trained
        in this epoch, or ``None`` if no batch was trained.
    """
    total_pred_loss, total_arg_loss, trn_results = 0, 0, None
    epoch_steps = int(args.total_steps / args.epochs)

    iterator = tqdm(enumerate(trn_loader), desc='steps', total=epoch_steps)
    for step, batch in iterator:
        # Checked before any update so that exactly `epoch_steps` batches are
        # trained (and the scheduler is stepped at most that many times) per
        # epoch; `step` is 0-based.
        if step >= epoch_steps:
            break

        token_ids, att_mask, single_pred_label, single_arg_label, all_pred_label = [
            tensor.to(args.device) for tensor in batch]
        pred_mask = bio.get_pred_mask(single_pred_label)

        # Re-enable training mode on every step; the interim evaluation
        # (`extract`) presumably switches the model to eval mode -- TODO confirm.
        model.train()
        model.zero_grad()

        # feed to predicate model
        batch_loss, pred_loss, arg_loss = model(
            input_ids=token_ids,
            attention_mask=att_mask,
            predicate_mask=pred_mask,
            total_pred_labels=all_pred_label,
            arg_labels=single_arg_label)

        # Running sums of the per-batch losses, used for the epoch averages.
        total_pred_loss += pred_loss.item()
        total_arg_loss += arg_loss.item()

        batch_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        # `step + 1` batches have been trained so far in this epoch.
        trn_results = [total_pred_loss / (step + 1), total_arg_loss / (step + 1)]

        # interim evaluation
        if step % 1000 == 0 and step != 0:
            _evaluate_on_dev(args, epoch, step, model, dev_loaders,
                             summarizer, trn_results)

        if step % args.summary_step == 0 and step != 0:
            utils.print_results(f"EPOCH{epoch} STEP{step} TRAIN",
                                trn_results, ["PRED LOSS", "ARG LOSS "])

    # end epoch summary (skipped when nothing was trained, since there are
    # no losses to report)
    if trn_results is not None:
        utils.print_results(f"EPOCH{epoch} TRAIN",
                            trn_results, ["PRED LOSS", "ARG LOSS "])
    return trn_results


def _evaluate_on_dev(args, epoch, step, model, dev_loaders, summarizer, trn_results):
    """Evaluate on every dev set, record a summary row and save a checkpoint.

    For each dev set, extractions are written to
    ``<args.save_path>/epoch{epoch}_dev/step{step}/<dev name>`` and scored by
    ``do_eval`` against the matching gold file. The first and last entries of
    each result are summed. Assuming ``do_eval`` returns metrics in the order
    of the printed labels (F1, PREC, REC, AUC), that sum is F1 + AUC. The sum
    is appended to that dev set's result, and the total across dev sets names
    the checkpoint.
    """
    dev_results = []
    total_sum = 0
    dev_iter = zip(args.dev_data_path, args.dev_gold_path, dev_loaders)
    for dev_input, dev_gold, dev_loader in dev_iter:
        dev_name = dev_input.split('/')[-1].replace('.pkl', '')
        output_path = os.path.join(args.save_path, f'epoch{epoch}_dev/step{step}/{dev_name}')
        extract(args, model, dev_loader, output_path)
        dev_result = do_eval(output_path, dev_gold)
        utils.print_results(f"EPOCH{epoch} STEP{step} EVAL",
                            dev_result, ["F1 ", "PREC", "REC ", "AUC "])
        # Computed once, before the append changes what dev_result[-1] refers to.
        score = dev_result[0] + dev_result[-1]
        total_sum += score
        dev_result.append(score)
        dev_results += dev_result
    summarizer.save_results([step] + trn_results + dev_results + [total_sum])
    model_name = utils.set_model_name(total_sum, epoch, step)
    torch.save(model.state_dict(), os.path.join(args.save_path, model_name))