code_notes.txt

Here are some tensor sizes put together:
The following numbers assume: batch size = 1, number of sentences = 65, total words in the context = 1767, words in the question = 17
context_idxs : torch.Size([1, 1767])
ques_idxs : torch.Size([1, 17])
context_char_idxs : torch.Size([1, 1767, 16])
ques_char_idxs : torch.Size([1, 17, 16])
context_lens : tensor([1767])
y1 : torch.Size([1])
y2 : torch.Size([1])
q_type : torch.Size([1])
is_support : torch.Size([1, 65])
start_mapping : torch.Size([1, 1767, 65])
end_mapping : torch.Size([1, 1767, 65])
all_mapping : torch.Size([1, 1767, 65])
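If you want to shape-check the forward pass without the data pipeline, here is a minimal sketch that builds dummy inputs of exactly these sizes (names as above; contents are zeros, so only the shapes are meaningful):

import torch
bsz, sent_num, para_size, ques_size, char_size = 1, 65, 1767, 17, 16
context_idxs = torch.zeros(bsz, para_size, dtype=torch.long)
ques_idxs = torch.zeros(bsz, ques_size, dtype=torch.long)
context_char_idxs = torch.zeros(bsz, para_size, char_size, dtype=torch.long)
ques_char_idxs = torch.zeros(bsz, ques_size, char_size, dtype=torch.long)
context_lens = torch.tensor([para_size])
y1 = torch.zeros(bsz, dtype=torch.long)      # answer start word index
y2 = torch.zeros(bsz, dtype=torch.long)      # answer end word index
q_type = torch.zeros(bsz, dtype=torch.long)  # question type label
is_support = torch.zeros(bsz, sent_num, dtype=torch.long)
start_mapping = torch.zeros(bsz, para_size, sent_num)
end_mapping = torch.zeros(bsz, para_size, sent_num)
all_mapping = torch.zeros(bsz, para_size, sent_num)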
Inside forward:
sp_output : torch.Size([1, 1767, 160])
start_output : torch.Size([1, 65, 80])
end_output : torch.Size([1, 65, 80])
sp_output[:,:,self.hidden:] : torch.Size([1, 1767, 80])
sp_output concatenated (via start_mapping, end_mapping): torch.Size([1, 65, 160])
sp_output_aux : torch.Size([1, 65, 1])
sp_output_t : torch.Size([1, 65, 1])
predict_support : torch.Size([1, 65, 2])
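The jump from [1, 1767, 160] to [1, 65, 160] comes from pooling the per-word sp_output into per-sentence vectors through the mappings. A self-contained sketch of how that could work (hidden = 80 here; which half of sp_output pairs with which mapping is an assumption, not copied from the repo):

import torch
bsz, para_size, sent_num, hidden = 1, 1767, 65, 80
sp_output = torch.randn(bsz, para_size, 2 * hidden)
start_mapping = torch.zeros(bsz, para_size, sent_num)
end_mapping = torch.zeros(bsz, para_size, sent_num)
# select each sentence's first/last word state via batched matmul with the one-hot mappings
start_output = torch.matmul(start_mapping.permute(0, 2, 1), sp_output[:, :, hidden:])  # [1, 65, 80]
end_output = torch.matmul(end_mapping.permute(0, 2, 1), sp_output[:, :, :hidden])      # [1, 65, 80]
sp_sent = torch.cat([start_output, end_output], dim=-1)                                # [1, 65, 160]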
predicted:
logit1 : torch.Size([1, 1767])
logit2 : torch.Size([1, 1767])
predict_type : torch.Size([1, 3])
predict_support : torch.Size([1, 65, 2])
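A hedged sketch of how these outputs could be consumed at prediction time (shape-level only; the repo presumably masks padding and does a constrained search over start/end pairs rather than independent argmaxes):

import torch
logit1, logit2 = torch.randn(1, 1767), torch.randn(1, 1767)
predict_type, predict_support = torch.randn(1, 3), torch.randn(1, 65, 2)
answer_type = predict_type.argmax(dim=-1)                      # [1]: e.g. span / yes / no
sp_mask = predict_support[:, :, 1] > predict_support[:, :, 0]  # [1, 65] supporting-fact decision
start = logit1.argmax(dim=-1)                                  # [1] answer start word index
end = logit2.argmax(dim=-1)                                    # [1] answer end word index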
These 3 numbers change from one batch to another:
number of sentences = 65, total words in the context = 1767, words in the question = 17
The mapping variables (start_mapping, end_mapping) are one-hot vectors that simply tell where every sentence begins and where every sentence ends.
The all_mapping variable just has 1s from the start to the end of each sentence, marking with 1 all the words that belong to that sentence.
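A small sketch of how such mappings could be built from per-sentence (first, last) word indices (the span list and the inclusive-end convention are assumptions, not taken from the repo):

import torch
start_mapping = torch.zeros(1, 1767, 65)
end_mapping = torch.zeros(1, 1767, 65)
all_mapping = torch.zeros(1, 1767, 65)
sent_spans = [(0, 11), (12, 30)]  # hypothetical (first, last) word index of each sentence
for j, (s, e) in enumerate(sent_spans):
    start_mapping[0, s, j] = 1      # one-hot at the first word of sentence j
    end_mapping[0, e, j] = 1        # one-hot at the last word of sentence j
    all_mapping[0, s:e + 1, j] = 1  # 1 for every word inside sentence j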
Prints inside forward, upper part (these come from a different example than the one above, hence 1307 context words and 29 question words):
print('context_mask: ',context_mask.shape)
print('ques_mask: ',ques_mask.shape)
context_mask: torch.Size([1, 1307])
ques_mask: torch.Size([1, 29])
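The masks presumably flag real tokens versus padding; a minimal sketch of how they could be derived from the index tensors (the padding-index-0 convention is an assumption):

import torch
context_idxs = torch.randint(1, 100, (1, 1307))
context_mask = (context_idxs > 0).float()  # 1.0 for real tokens, 0.0 for padding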
# embed every character; words are flattened so each word becomes a (char_size, char_dim) matrix
context_ch = self.char_emb(context_char_idxs.contiguous().view(-1, char_size)).view(bsz * para_size, char_size, -1)
ques_ch = self.char_emb(ques_char_idxs.contiguous().view(-1, char_size)).view(bsz * ques_size, char_size, -1)
print('context_ch: ',context_ch.shape)
print('ques_ch: ',ques_ch.shape)
context_ch: torch.Size([1307, 16, 8])
ques_ch: torch.Size([29, 16, 8])
# run the char CNN over the character positions and max-pool them away, then restore the batch dim
context_ch = self.char_cnn(context_ch.permute(0, 2, 1).contiguous()).max(dim=-1)[0].view(bsz, para_size, -1)
ques_ch = self.char_cnn(ques_ch.permute(0, 2, 1).contiguous()).max(dim=-1)[0].view(bsz, ques_size, -1)
print('context_ch: ',context_ch.shape)
print('ques_ch: ',ques_ch.shape)
context_ch: torch.Size([1, 1307, 100])
ques_ch: torch.Size([1, 29, 100])
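So the char CNN convolves over the 16 character positions and max-pools them away, leaving one 100-dim feature vector per word. A standalone sketch of that step (the kernel size is an assumption):

import torch
import torch.nn as nn
char_cnn = nn.Conv1d(in_channels=8, out_channels=100, kernel_size=5)  # assumed kernel size
x = torch.randn(1307, 16, 8)                     # (words, char positions, char emb dim)
x = char_cnn(x.permute(0, 2, 1)).max(dim=-1)[0]  # conv over chars, max-pool -> [1307, 100]
x = x.view(1, 1307, -1)                          # back to [bsz, para_size, 100]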
# look up the 300-dim word embedding for every token
context_word = self.word_emb(context_idxs)
ques_word = self.word_emb(ques_idxs)
print('context_word: ',context_word.shape)
print('ques_word: ',ques_word.shape)
context_word: torch.Size([1, 1307, 300])
ques_word: torch.Size([1, 29, 300])
# concatenate word and char features along the feature dim (300 + 100 = 400)
context_output = torch.cat([context_word, context_ch], dim=2)
ques_output = torch.cat([ques_word, ques_ch], dim=2)
print('context_output: ',context_output.shape)
print('ques_output: ',ques_output.shape)
context_output: torch.Size([1, 1307, 400])
ques_output: torch.Size([1, 29, 400])
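So each word ends up represented by 400 features: 300 from the word embedding plus 100 from the char CNN. A quick shape check with the names above:

assert context_output.shape == (1, 1307, 400)  # 300 word dims + 100 char dims
assert ques_output.shape == (1, 29, 400)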