# visualize.py
# Forked from tomsherborne/nlu_cw2.
import os
import argparse
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
from torch.serialization import default_restore_location
from seq2seq import models, utils
from seq2seq.data.dictionary import Dictionary
from seq2seq.data.dataset import Seq2SeqDataset, BatchSampler
def _str_to_bool(value):
    """ Interprets a command-line string as a boolean flag value.

    Only the usual spellings of "false" (and the empty string) map to False;
    every other string maps to True, so invocations that relied on any
    non-empty value enabling the flag keep working.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')


def get_args():
    """ Defines the command-line options of the attention visualisation script.

    Returns:
        argparse.Namespace with the attributes ``data``, ``source_lang``,
        ``target_lang``, ``checkpoint_path``, ``vis_dir`` and ``cuda``, parsed
        from ``sys.argv``.
    """
    parser = argparse.ArgumentParser('Sequence to Sequence Model')

    # Add data arguments
    parser.add_argument('--data', default='europarl_prepared',
                        help='path to data directory')
    parser.add_argument('--source-lang', default='de', help='source language')
    parser.add_argument('--target-lang', default='en', help='target language')
    parser.add_argument('--checkpoint-path', default='checkpoints/checkpoint_best.pt',
                        help='path to the model file')
    parser.add_argument('--vis-dir', default='visualizations',
                        help='path to the directory where the heat-maps are saved')

    # Accept both a bare `--cuda` and an explicit `--cuda True/False`. Without a
    # converter the raw string 'False' would be truthy and silently enable the GPU.
    parser.add_argument('--cuda', nargs='?', const=True, default=False,
                        type=_str_to_bool, help='Use a GPU')
    return parser.parse_args()
def main(args):
    """ Main function. Visualises attention weight arrays as nifty heat-maps.

    Restores the model stored at ``args.checkpoint_path``, runs it over the test
    split found in ``args.data`` one sentence pair at a time, and writes one
    heat-map per recorded pair to ``args.vis_dir`` as ``sentence_<n>.png``.

    Args:
        args: namespace of command-line options. Every option that is also
            present in the checkpoint's saved arguments is overridden by the
            saved value.
    """
    torch.manual_seed(42)

    # NOTE(review): torch.load unpickles the file, so only open trusted checkpoints.
    state_dict = torch.load(
        args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu'))

    # The saved arguments may embed this machine-specific absolute prefix;
    # removing it leaves paths relative to the current working directory.
    saved_args = vars(state_dict['args'])
    stale_prefix = '/home/lvyajie/nlu_cw2/'
    for key, value in saved_args.items():
        if isinstance(value, str) and stale_prefix in value:
            saved_args[key] = value.replace(stale_prefix, '')

    # Saved arguments win over the command-line ones on key collisions.
    args = argparse.Namespace(**{**vars(args), **saved_args})
    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(os.path.join(
        args.data, 'dict.{:s}'.format(args.source_lang)))
    print('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(os.path.join(
        args.data, 'dict.{:s}'.format(args.target_lang)))
    print('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load dataset
    test_dataset = Seq2SeqDataset(
        src_file=os.path.join(args.data, 'test.{:s}'.format(args.source_lang)),
        tgt_file=os.path.join(args.data, 'test.{:s}'.format(args.target_lang)),
        src_dict=src_dict, tgt_dict=tgt_dict)

    vis_loader = torch.utils.data.DataLoader(test_dataset, num_workers=1, collate_fn=test_dataset.collater,
                                             batch_sampler=BatchSampler(test_dataset, None, 1, 1, 0, shuffle=False,
                                                                        seed=42))

    # Build model and restore its weights
    model = models.build_model(args, src_dict, tgt_dict)
    use_cuda = bool(args.cuda) and torch.cuda.is_available()
    if use_cuda:
        model = model.cuda()
    model.load_state_dict(state_dict['model'])
    # Evaluation mode keeps train-only layers (e.g. dropout, if the model has
    # any) from perturbing the attention weights that get plotted.
    model.eval()
    print('Loaded a model from checkpoint {:s}'.format(args.checkpoint_path))

    # Store attention weight arrays
    attn_records = []
    max_sentences = 100

    # Iterate over the visualization set; no gradients are needed, and tracking
    # them would keep an autograd graph alive for every stored attention tensor.
    with torch.no_grad():
        for sample in vis_loader:
            if len(sample) == 0:
                continue

            if use_cuda:
                for k in sample:
                    if isinstance(sample[k], torch.Tensor):
                        sample[k] = sample[k].cuda()

            # Perform forward pass; only the attention weights are of interest
            _, attn_weights = model(
                sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs'])
            attn_records.append((sample, attn_weights))

            # Only visualize the first `max_sentences` sentence pairs
            if len(attn_records) >= max_sentences:
                break

    # Generate heat-maps and store them at the designated location
    os.makedirs(args.vis_dir, exist_ok=True)

    for record_id, (sample, attn_map) in enumerate(attn_records):
        # Strip padding with each side's own dictionary
        src_ids = utils.strip_pad(sample['src_tokens'].data, src_dict.pad_idx)
        tgt_ids = utils.strip_pad(sample['tgt_inputs'].data, tgt_dict.pad_idx)

        # Convert indices into word tokens
        src_str = src_dict.string(src_ids).split(' ') + ['<EOS>']
        tgt_str = tgt_dict.string(tgt_ids).split(' ') + ['<EOS>']

        # Generate heat-maps; .cpu() is a no-op for tensors already on the CPU,
        # so a single code path serves both devices. The batch axis is dropped
        # and the map transposed so rows are source tokens and columns target tokens.
        attn_map = attn_map.squeeze(dim=0).transpose(1, 0).detach().cpu().numpy()

        attn_df = pd.DataFrame(attn_map,
                               index=src_str,
                               columns=tgt_str)

        sns.heatmap(attn_df, cmap='Blues', linewidths=0.25, vmin=0.0, vmax=1.0, xticklabels=True, yticklabels=True,
                    fmt='.3f')
        plt.yticks(rotation=0)
        plot_path = os.path.join(
            args.vis_dir, 'sentence_{:d}.png'.format(record_id))
        plt.savefig(plot_path, dpi='figure', pad_inches=1, bbox_inches='tight')
        # Clear the figure so the next heat-map starts from an empty canvas
        plt.clf()

    print('Done! Visualised attention maps have been saved to the \'{:s}\' directory!'.format(
        args.vis_dir))
if __name__ == '__main__':
    # Script entry point: parse the command line and hand the options to main().
    main(get_args())