<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>10_Seq2Seq_Attention</title>
<link rel="stylesheet" href="https://stackedit.io/style.css" />
</head>
<body class="stackedit">
<div class="stackedit__html"><h1 id="seq2seq-attention">10 Seq2Seq Attention</h1>
<h2 id="assignment">Assignment</h2>
<ol>
<li>Replace the embeddings of this session’s code with GloVe embeddings</li>
<li>Compare your results with this session’s code.</li>
<li>Upload to a public GitHub repo and proceed to Session 10 Assignment Solutions where these questions are asked:
<ol>
<li>Share the link to your README file’s public repo for this assignment. Expecting a minimum 500-word write-up on your learnings. Expecting you to compare your results with the code covered in the class. - 750 Points</li>
<li>Share the link to your main notebook with training logs - 250 Points</li>
</ol>
</li>
</ol>
<h2 id="solution">Solution</h2>
<table>
<thead>
<tr>
<th></th>
<th>NBViewer</th>
<th>Google Colab</th>
</tr>
</thead>
<tbody>
<tr>
<td>Old Code - French to English</td>
<td><a href="https://nbviewer.jupyter.org/github/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/10_Seq2Seq_Attention/END2_Translation_using_Seq2Seq_and_Attention.ipynb"><img alt="Open In NBViewer" src="https://img.shields.io/badge/render-nbviewer-orange?logo=Jupyter"></a></td>
<td><a href="https://githubtocolab.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/10_Seq2Seq_Attention/END2_Translation_using_Seq2Seq_and_Attention.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></td>
</tr>
<tr>
<td><strong>New Code</strong> - English to French w/ GloVe Embeddings</td>
<td><a href="https://nbviewer.jupyter.org/github/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/10_Seq2Seq_Attention/Seq2Seq_Attention.ipynb"><img alt="Open In NBViewer" src="https://img.shields.io/badge/render-nbviewer-orange?logo=Jupyter"></a></td>
<td><a href="https://githubtocolab.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/10_Seq2Seq_Attention/Seq2Seq_Attention.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></td>
</tr>
</tbody>
</table><p>If someday PyTorch decides to remove the <code>data.zip</code> file, I’ve added it to <a href="https://github.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/10_Seq2Seq_Attention/data.zip">this repository</a>.</p>
<h3 id="creating-the-dataset">Creating the Dataset</h3>
<p>Some of the dataset code was changed so that it works with the PyTorch Lightning data module and model.</p>
<p>For example, <code>build_vocab_from_iterator</code> is used to build a <code>Vocab</code> object for each language. This vocabulary is later used with the pretrained word embeddings, to map its tokens onto GloVe’s vocabulary.</p>
<pre class=" language-python"><code class="prism language-python"><span class="token keyword">def</span> <span class="token function">prepare_langs</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> lang_file<span class="token operator">=</span><span class="token string">'eng-fra'</span><span class="token punctuation">,</span> reverse<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
    <span class="token keyword">with</span> urlopen<span class="token punctuation">(</span>self<span class="token punctuation">.</span>zip_url<span class="token punctuation">)</span> <span class="token keyword">as</span> f<span class="token punctuation">:</span>
        <span class="token keyword">with</span> BytesIO<span class="token punctuation">(</span>f<span class="token punctuation">.</span>read<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span> <span class="token keyword">as</span> b<span class="token punctuation">,</span> ZipFile<span class="token punctuation">(</span>b<span class="token punctuation">)</span> <span class="token keyword">as</span> datazip<span class="token punctuation">:</span>
            lang1<span class="token punctuation">,</span> lang2 <span class="token operator">=</span> lang_file<span class="token punctuation">.</span>split<span class="token punctuation">(</span><span class="token string">'-'</span><span class="token punctuation">)</span>
            pairs <span class="token operator">=</span> readPairs<span class="token punctuation">(</span>datazip<span class="token punctuation">,</span> lang1<span class="token punctuation">,</span> lang2<span class="token punctuation">,</span> reverse<span class="token punctuation">)</span>
    <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"Read %s sentence pairs"</span> <span class="token operator">%</span> <span class="token builtin">len</span><span class="token punctuation">(</span>pairs<span class="token punctuation">)</span><span class="token punctuation">)</span>
    pairs <span class="token operator">=</span> filterPairs<span class="token punctuation">(</span>pairs<span class="token punctuation">)</span>
    <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"Trimmed to %s sentence pairs"</span> <span class="token operator">%</span> <span class="token builtin">len</span><span class="token punctuation">(</span>pairs<span class="token punctuation">)</span><span class="token punctuation">)</span>
    <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"Counting words..."</span><span class="token punctuation">)</span>
    input_sentences<span class="token punctuation">,</span> target_sentences <span class="token operator">=</span> <span class="token builtin">zip</span><span class="token punctuation">(</span><span class="token operator">*</span>pairs<span class="token punctuation">)</span>
    input_lang <span class="token operator">=</span> build_vocab_from_iterator<span class="token punctuation">(</span>
        <span class="token punctuation">[</span>sentence<span class="token punctuation">.</span>split<span class="token punctuation">(</span><span class="token string">' '</span><span class="token punctuation">)</span> <span class="token keyword">for</span> sentence <span class="token keyword">in</span> input_sentences<span class="token punctuation">]</span><span class="token punctuation">,</span>
        specials<span class="token operator">=</span>special_tokens
    <span class="token punctuation">)</span>
    output_lang <span class="token operator">=</span> build_vocab_from_iterator<span class="token punctuation">(</span>
        <span class="token punctuation">[</span>sentence<span class="token punctuation">.</span>split<span class="token punctuation">(</span><span class="token string">' '</span><span class="token punctuation">)</span> <span class="token keyword">for</span> sentence <span class="token keyword">in</span> target_sentences<span class="token punctuation">]</span><span class="token punctuation">,</span>
        specials<span class="token operator">=</span>special_tokens
    <span class="token punctuation">)</span>
    <span class="token builtin">setattr</span><span class="token punctuation">(</span>input_lang<span class="token punctuation">,</span> <span class="token string">'name'</span><span class="token punctuation">,</span> lang2 <span class="token keyword">if</span> reverse <span class="token keyword">else</span> lang1<span class="token punctuation">)</span>
    <span class="token builtin">setattr</span><span class="token punctuation">(</span>output_lang<span class="token punctuation">,</span> <span class="token string">'name'</span><span class="token punctuation">,</span> lang1 <span class="token keyword">if</span> reverse <span class="token keyword">else</span> lang2<span class="token punctuation">)</span>
    <span class="token builtin">setattr</span><span class="token punctuation">(</span>input_lang<span class="token punctuation">,</span> <span class="token string">'n_words'</span><span class="token punctuation">,</span> <span class="token builtin">len</span><span class="token punctuation">(</span>input_lang<span class="token punctuation">)</span><span class="token punctuation">)</span>
    <span class="token builtin">setattr</span><span class="token punctuation">(</span>output_lang<span class="token punctuation">,</span> <span class="token string">'n_words'</span><span class="token punctuation">,</span> <span class="token builtin">len</span><span class="token punctuation">(</span>output_lang<span class="token punctuation">)</span><span class="token punctuation">)</span>
    <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"Counted words:"</span><span class="token punctuation">)</span>
    <span class="token keyword">print</span><span class="token punctuation">(</span>input_lang<span class="token punctuation">.</span>name<span class="token punctuation">,</span> input_lang<span class="token punctuation">.</span>n_words<span class="token punctuation">)</span>
    <span class="token keyword">print</span><span class="token punctuation">(</span>output_lang<span class="token punctuation">.</span>name<span class="token punctuation">,</span> output_lang<span class="token punctuation">.</span>n_words<span class="token punctuation">)</span>
    <span class="token keyword">return</span> input_lang<span class="token punctuation">,</span> output_lang<span class="token punctuation">,</span> pairs
</code></pre>
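<p>To make the vocabulary step concrete, here is a minimal pure-Python sketch of what building a vocab with special tokens amounts to. It is a simplified stand-in, not the torchtext implementation: the function <code>build_vocab</code>, the special token names and the toy sentence pairs are all assumptions made for illustration.</p>

```python
from collections import Counter


def build_vocab(token_lists, specials=("PAD", "SOS", "EOS")):
    """Map tokens to integer indices, with the special tokens first.

    Simplified stand-in for a torchtext vocab: the remaining tokens are
    ordered by descending frequency, ties broken alphabetically.
    """
    counts = Counter(tok for tokens in token_lists for tok in tokens)
    for tok in specials:
        counts.pop(tok, None)  # specials are placed explicitly below
    ordered = sorted(counts, key=lambda tok: (-counts[tok], tok))
    itos = list(specials) + ordered
    stoi = {tok: idx for idx, tok in enumerate(itos)}
    return stoi, itos


# Hypothetical, already-normalized (English, French) sentence pairs.
pairs = [("i am cold", "j ai froid"), ("i am here", "je suis ici")]
input_sentences, target_sentences = zip(*pairs)
stoi, itos = build_vocab([s.split(" ") for s in input_sentences])
print(itos)  # ['PAD', 'SOS', 'EOS', 'am', 'i', 'cold', 'here']
```

<p>The <code>itos</code> list (index to string) is the piece that matters later: it is what gets passed to the pretrained vectors so that every row of the embedding matrix lines up with the right token.</p>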
<h3 id="encoder-and-decoder">Encoder and Decoder</h3>
<p>The encoder of a seq2seq network is an RNN that outputs some value for every word from the input sentence. For every input word the encoder outputs a vector and a hidden state, and uses the hidden state for the next input word.</p>
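<p>The loop described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: it uses a plain tanh RNN cell rather than the GRU used in the notebook, and the function names, the 2-dimensional embeddings and the weights are made up.</p>

```python
import math


def rnn_step(x, h, w_xh, w_hh):
    """One tanh RNN step: combine the current input with the previous hidden state."""
    size = len(h)
    return [
        math.tanh(
            sum(x[i] * w_xh[i][j] for i in range(len(x)))
            + sum(h[k] * w_hh[k][j] for k in range(size))
        )
        for j in range(size)
    ]


def encode(embedded_words, hidden_size, w_xh, w_hh):
    """Run the cell over a sentence; return every output and the final hidden state."""
    hidden = [0.0] * hidden_size
    outputs = []
    for x in embedded_words:
        hidden = rnn_step(x, hidden, w_xh, w_hh)
        outputs.append(hidden)  # single layer: the output is the hidden state
    return outputs, hidden


# Hypothetical embeddings for a 3-word sentence and fixed toy weights.
sentence = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w_xh = [[0.5, -0.5], [0.25, 0.75]]
w_hh = [[0.1, 0.0], [0.0, 0.1]]
encoder_outputs, context = encode(sentence, 2, w_xh, w_hh)
print(len(encoder_outputs))  # 3, one output vector per input word
```

<p>Here <code>context</code> is the last hidden state, and <code>encoder_outputs</code> holds one vector per input word, which is what the attention mechanism consumes.</p>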
<p>If only the context vector is passed between the encoder and decoder, that single vector carries the burden of encoding the entire sentence.</p>
<p>Attention allows the decoder network to “focus” on a different part of the encoder’s outputs for every step of the decoder’s own outputs. First we calculate a set of <em>attention weights</em>. These will be multiplied by the encoder output vectors to create a weighted combination. The result (called <code>attn_applied</code> in the code) should contain information about that specific part of the input sequence, and thus help the decoder choose the right output words.</p>
<p>Calculating the attention weights is done with another feed-forward layer <code>attn</code>, using the decoder’s input and hidden state as inputs. Because there are sentences of all sizes in the training data, to actually create and train this layer we have to choose a maximum sentence length (input length, for encoder outputs) that it can apply to. Sentences of the maximum length will use all the attention weights, while shorter sentences will only use the first few.</p>
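<p>As a rough sketch of the two paragraphs above, the attention step can be written in plain Python: score every input position from the decoder’s input and hidden state, turn the scores into weights with a softmax, and use the weights to mix the encoder outputs. The value of <code>MAX_LENGTH</code>, the weight matrix and all the vectors below are hypothetical, and the bias term of the <code>attn</code> layer is left out for brevity.</p>

```python
import math

MAX_LENGTH = 4  # assumed maximum input length for this toy example


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention(decoder_input, hidden, attn_weight, encoder_outputs):
    """Return (attn_weights, attn_applied) for one decoder step."""
    features = decoder_input + hidden  # list concatenation
    # One score per input position: a linear layer without bias.
    scores = [sum(f * w for f, w in zip(features, row)) for row in attn_weight]
    attn_weights = softmax(scores)
    dim = len(encoder_outputs[0])
    # Weighted combination of the encoder output vectors.
    attn_applied = [
        sum(attn_weights[pos] * encoder_outputs[pos][d] for pos in range(MAX_LENGTH))
        for d in range(dim)
    ]
    return attn_weights, attn_applied


# A 2-word sentence: the encoder outputs are zero-padded up to MAX_LENGTH.
encoder_outputs = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]
decoder_input, hidden = [0.5, -0.5], [1.0, 0.0]
attn_weight = [
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 1.0, 0.0, 1.0],
    [0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0],
]
attn_weights, attn_applied = attention(decoder_input, hidden, attn_weight, encoder_outputs)
print(round(sum(attn_weights), 6))  # 1.0
```

<p>Because the padded positions hold zero vectors, whatever weight they receive adds nothing to <code>attn_applied</code>. That is the sense in which shorter sentences only use the first few attention weights.</p>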
<h3 id="using-pretrained-glove-embeddings">Using Pretrained <code>GloVe</code> Embeddings</h3>
<p><strong>Glo</strong>bal <strong>Ve</strong>ctors for Word Representation, or GloVe, is an “<a href="https://nlp.stanford.edu/projects/glove/">unsupervised learning algorithm for obtaining vector representations for words.</a>” Simply put, GloVe takes a corpus of text and maps each word in that corpus to a position in a high-dimensional vector space, so that words with similar meanings end up close together.</p>
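<p>“Close together” is usually measured with cosine similarity. The snippet below is a small illustration with made-up 3-dimensional vectors (real GloVe vectors have tens to hundreds of dimensions, and these numbers are not taken from GloVe).</p>

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two vectors: 1.0 means same direction."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# Hypothetical toy vectors, chosen only to illustrate the idea.
vectors = {
    "king": [0.9, 0.8, 0.1],
    "queen": [0.85, 0.9, 0.15],
    "apple": [0.1, 0.2, 0.95],
}
related = cosine_similarity(vectors["king"], vectors["queen"])
unrelated = cosine_similarity(vectors["king"], vectors["apple"])
print(related, unrelated)  # the first value is much closer to 1 than the second
```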
<p>I found this nice way of using <code>Embeddings</code> with <code>GloVe</code> <code>Vectors</code>:</p>
<pre class=" language-python"><code class="prism language-python"><span class="token keyword">from</span> torchtext<span class="token punctuation">.</span>vocab <span class="token keyword">import</span> GloVe<span class="token punctuation">,</span> vocab
<span class="token keyword">from</span> torchtext<span class="token punctuation">.</span>datasets <span class="token keyword">import</span> AG_NEWS
<span class="token keyword">from</span> torchtext<span class="token punctuation">.</span>data<span class="token punctuation">.</span>utils <span class="token keyword">import</span> get_tokenizer
<span class="token keyword">import</span> torch
<span class="token keyword">import</span> torch<span class="token punctuation">.</span>nn <span class="token keyword">as</span> nn

<span class="token comment">#define your model that accepts pretrained embeddings</span>
<span class="token keyword">class</span> <span class="token class-name">TextClassificationModel</span><span class="token punctuation">(</span>nn<span class="token punctuation">.</span>Module<span class="token punctuation">)</span><span class="token punctuation">:</span>
    <span class="token keyword">def</span> <span class="token function">__init__</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> pretrained_embeddings<span class="token punctuation">,</span> num_class<span class="token punctuation">,</span> freeze_embeddings <span class="token operator">=</span> <span class="token boolean">False</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
        <span class="token builtin">super</span><span class="token punctuation">(</span>TextClassificationModel<span class="token punctuation">,</span> self<span class="token punctuation">)</span><span class="token punctuation">.</span>__init__<span class="token punctuation">(</span><span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>embedding <span class="token operator">=</span> nn<span class="token punctuation">.</span>EmbeddingBag<span class="token punctuation">.</span>from_pretrained<span class="token punctuation">(</span>pretrained_embeddings<span class="token punctuation">,</span> freeze <span class="token operator">=</span> freeze_embeddings<span class="token punctuation">,</span> sparse<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>fc <span class="token operator">=</span> nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span>pretrained_embeddings<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">,</span> num_class<span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>init_weights<span class="token punctuation">(</span><span class="token punctuation">)</span>

    <span class="token keyword">def</span> <span class="token function">init_weights</span><span class="token punctuation">(</span>self<span class="token punctuation">)</span><span class="token punctuation">:</span>
        initrange <span class="token operator">=</span> <span class="token number">0.5</span>
        self<span class="token punctuation">.</span>fc<span class="token punctuation">.</span>weight<span class="token punctuation">.</span>data<span class="token punctuation">.</span>uniform_<span class="token punctuation">(</span><span class="token operator">-</span>initrange<span class="token punctuation">,</span> initrange<span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>fc<span class="token punctuation">.</span>bias<span class="token punctuation">.</span>data<span class="token punctuation">.</span>zero_<span class="token punctuation">(</span><span class="token punctuation">)</span>

    <span class="token keyword">def</span> <span class="token function">forward</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> text<span class="token punctuation">,</span> offsets<span class="token punctuation">)</span><span class="token punctuation">:</span>
        embedded <span class="token operator">=</span> self<span class="token punctuation">.</span>embedding<span class="token punctuation">(</span>text<span class="token punctuation">,</span> offsets<span class="token punctuation">)</span>
        <span class="token keyword">return</span> self<span class="token punctuation">.</span>fc<span class="token punctuation">(</span>embedded<span class="token punctuation">)</span>

train_iter <span class="token operator">=</span> AG_NEWS<span class="token punctuation">(</span>split <span class="token operator">=</span> <span class="token string">'train'</span><span class="token punctuation">)</span>
num_class <span class="token operator">=</span> <span class="token builtin">len</span><span class="token punctuation">(</span><span class="token builtin">set</span><span class="token punctuation">(</span><span class="token punctuation">[</span>label <span class="token keyword">for</span> <span class="token punctuation">(</span>label<span class="token punctuation">,</span> _<span class="token punctuation">)</span> <span class="token keyword">in</span> train_iter<span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
unk_token <span class="token operator">=</span> <span class="token string">"&lt;unk&gt;"</span>
unk_index <span class="token operator">=</span> <span class="token number">0</span>
glove_vectors <span class="token operator">=</span> GloVe<span class="token punctuation">(</span><span class="token punctuation">)</span>
<span class="token comment">#min_freq=0 keeps the token whose index in stoi is 0; the default min_freq=1 would drop it and shift the alignment by one</span>
glove_vocab <span class="token operator">=</span> vocab<span class="token punctuation">(</span>glove_vectors<span class="token punctuation">.</span>stoi<span class="token punctuation">,</span> min_freq<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span>
glove_vocab<span class="token punctuation">.</span>insert_token<span class="token punctuation">(</span><span class="token string">"&lt;unk&gt;"</span><span class="token punctuation">,</span>unk_index<span class="token punctuation">)</span>
<span class="token comment">#this is necessary otherwise it will throw runtime error if OOV token is queried</span>
glove_vocab<span class="token punctuation">.</span>set_default_index<span class="token punctuation">(</span>unk_index<span class="token punctuation">)</span>
<span class="token comment">#prepend a zero vector so that row 0 of the embedding matrix belongs to the unk token</span>
pretrained_embeddings <span class="token operator">=</span> glove_vectors<span class="token punctuation">.</span>vectors
pretrained_embeddings <span class="token operator">=</span> torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span><span class="token punctuation">(</span>torch<span class="token punctuation">.</span>zeros<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">,</span>pretrained_embeddings<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">,</span>pretrained_embeddings<span class="token punctuation">)</span><span class="token punctuation">)</span>
<span class="token comment">#instantiate model with pre-trained glove vectors</span>
glove_model <span class="token operator">=</span> TextClassificationModel<span class="token punctuation">(</span>pretrained_embeddings<span class="token punctuation">,</span> num_class<span class="token punctuation">)</span>
tokenizer <span class="token operator">=</span> get_tokenizer<span class="token punctuation">(</span><span class="token string">"basic_english"</span><span class="token punctuation">)</span>
train_iter <span class="token operator">=</span> AG_NEWS<span class="token punctuation">(</span>split <span class="token operator">=</span> <span class="token string">'train'</span><span class="token punctuation">)</span>
example_text <span class="token operator">=</span> <span class="token builtin">next</span><span class="token punctuation">(</span><span class="token builtin">iter</span><span class="token punctuation">(</span>train_iter<span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span>
tokens <span class="token operator">=</span> tokenizer<span class="token punctuation">(</span>example_text<span class="token punctuation">)</span>
indices <span class="token operator">=</span> glove_vocab<span class="token punctuation">(</span>tokens<span class="token punctuation">)</span>
text_input <span class="token operator">=</span> torch<span class="token punctuation">.</span>tensor<span class="token punctuation">(</span>indices<span class="token punctuation">)</span>
offset_input <span class="token operator">=</span> torch<span class="token punctuation">.</span>tensor<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
model_output <span class="token operator">=</span> glove_model<span class="token punctuation">(</span>text_input<span class="token punctuation">,</span> offset_input<span class="token punctuation">)</span>
</code></pre>
<p><a href="https://github.com/pytorch/text/issues/1350">Source</a></p>
<p>And for using pretrained embeddings with an existing <code>Vocab</code> object:</p>
<pre class=" language-python"><code class="prism language-python"><span class="token keyword">import</span> torch<span class="token punctuation">.</span>nn <span class="token keyword">as</span> nn
<span class="token keyword">import</span> torchtext

min_freq <span class="token operator">=</span> <span class="token number">5</span>
special_tokens <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token string">'&lt;unk&gt;'</span><span class="token punctuation">,</span> <span class="token string">'&lt;pad&gt;'</span><span class="token punctuation">]</span>
<span class="token comment"># train_data and n_classes are assumed to be defined elsewhere</span>
vocab <span class="token operator">=</span> torchtext<span class="token punctuation">.</span>vocab<span class="token punctuation">.</span>build_vocab_from_iterator<span class="token punctuation">(</span>train_data<span class="token punctuation">[</span><span class="token string">'tokens'</span><span class="token punctuation">]</span><span class="token punctuation">,</span>
                                                  min_freq<span class="token operator">=</span>min_freq<span class="token punctuation">,</span>
                                                  specials<span class="token operator">=</span>special_tokens<span class="token punctuation">)</span>
<span class="token comment"># train_data['tokens'] is a list of lists of strings, i.e. [['hello', 'world'], ['goodbye', 'moon']], where ['hello', 'world'] holds the tokens of the first example in the training set.</span>
pretrained_vectors <span class="token operator">=</span> torchtext<span class="token punctuation">.</span>vocab<span class="token punctuation">.</span>FastText<span class="token punctuation">(</span><span class="token punctuation">)</span>
pretrained_embedding <span class="token operator">=</span> pretrained_vectors<span class="token punctuation">.</span>get_vecs_by_tokens<span class="token punctuation">(</span>vocab<span class="token punctuation">.</span>get_itos<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
<span class="token comment"># vocab.get_itos() returns a list of strings (tokens), where the token at position i is the one for which vocab[token] == i</span>
<span class="token comment"># get_vecs_by_tokens gets the pre-trained vector for each string when given a list of strings</span>
<span class="token comment"># therefore pretrained_embedding is a fully "aligned" embedding matrix</span>

<span class="token keyword">class</span> <span class="token class-name">NBoW</span><span class="token punctuation">(</span>nn<span class="token punctuation">.</span>Module<span class="token punctuation">)</span><span class="token punctuation">:</span>
    <span class="token keyword">def</span> <span class="token function">__init__</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> vocab_size<span class="token punctuation">,</span> embedding_dim<span class="token punctuation">,</span> output_dim<span class="token punctuation">)</span><span class="token punctuation">:</span>
        <span class="token builtin">super</span><span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>__init__<span class="token punctuation">(</span><span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>embedding <span class="token operator">=</span> nn<span class="token punctuation">.</span>Embedding<span class="token punctuation">(</span>vocab_size<span class="token punctuation">,</span> embedding_dim<span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>fc <span class="token operator">=</span> nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span>embedding_dim<span class="token punctuation">,</span> output_dim<span class="token punctuation">)</span>

    <span class="token keyword">def</span> <span class="token function">forward</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> text<span class="token punctuation">)</span><span class="token punctuation">:</span>
        <span class="token comment"># text = [batch size, seq len]</span>
        embedded <span class="token operator">=</span> self<span class="token punctuation">.</span>embedding<span class="token punctuation">(</span>text<span class="token punctuation">)</span>
        <span class="token comment"># embedded = [batch size, seq len, embedding dim]</span>
        pooled <span class="token operator">=</span> embedded<span class="token punctuation">.</span>mean<span class="token punctuation">(</span>dim<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>
        <span class="token comment"># pooled = [batch size, embedding dim]</span>
        prediction <span class="token operator">=</span> self<span class="token punctuation">.</span>fc<span class="token punctuation">(</span>pooled<span class="token punctuation">)</span>
        <span class="token comment"># prediction = [batch size, output dim]</span>
        <span class="token keyword">return</span> prediction

vocab_size <span class="token operator">=</span> <span class="token builtin">len</span><span class="token punctuation">(</span>vocab<span class="token punctuation">)</span>
embedding_dim <span class="token operator">=</span> <span class="token number">300</span>
output_dim <span class="token operator">=</span> n_classes
model <span class="token operator">=</span> NBoW<span class="token punctuation">(</span>vocab_size<span class="token punctuation">,</span> embedding_dim<span class="token punctuation">,</span> output_dim<span class="token punctuation">)</span>
<span class="token comment"># super basic model here; the important thing is that the embedding layer is initialized as nn.Embedding(vocab_size, embedding_dim) with embedding_dim = 300, as that's the dimension of the FastText embeddings</span>
model<span class="token punctuation">.</span>embedding<span class="token punctuation">.</span>weight<span class="token punctuation">.</span>data <span class="token operator">=</span> pretrained_embedding
<span class="token comment"># overwrite the model's initial embedding matrix weights with that of the pre-trained embeddings from FastText</span>
</code></pre>
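<p>The “aligned” part is the key idea, so here is a minimal pure-Python sketch of it. It mimics looking up one pretrained vector per vocabulary token, in vocabulary order, and falling back to a zero vector for tokens the pretrained table does not know (which is also what torchtext’s vectors do by default). The helper <code>align_vectors</code>, the toy table and the toy vocabulary are assumptions for illustration only.</p>

```python
def align_vectors(itos, pretrained, dim):
    """Return one row per vocab token; unknown tokens get a zero vector."""
    return [list(pretrained.get(token, [0.0] * dim)) for token in itos]


# Hypothetical pretrained table (token to vector) and vocabulary.
pretrained = {"hello": [0.1, 0.2], "world": [0.3, 0.4], "moon": [0.5, 0.6]}
itos = ["PAD", "hello", "moon", "goodbye"]
stoi = {token: index for index, token in enumerate(itos)}

embedding_matrix = align_vectors(itos, pretrained, dim=2)
print(embedding_matrix[stoi["moon"]])     # [0.5, 0.6]
print(embedding_matrix[stoi["goodbye"]])  # [0.0, 0.0]
```

<p>Row <code>i</code> of the matrix now belongs to the token with index <code>i</code>, so the matrix can be copied straight into an embedding layer whose size matches the vocabulary.</p>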
<p>And this is how I integrated GloVe embeddings into this assignment:</p>
<pre class=" language-python"><code class="prism language-python"><span class="token keyword">import</span> torch<span class="token punctuation">.</span>nn <span class="token keyword">as</span> nn
<span class="token keyword">import</span> torchtext

<span class="token comment"># PAD_token (the padding index) is assumed to be defined earlier in the notebook</span>
<span class="token keyword">class</span> <span class="token class-name">EncoderRNN</span><span class="token punctuation">(</span>nn<span class="token punctuation">.</span>Module<span class="token punctuation">)</span><span class="token punctuation">:</span>
    <span class="token keyword">def</span> <span class="token function">__init__</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> input_size<span class="token punctuation">,</span> hidden_size<span class="token punctuation">,</span> use_pretrained<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">,</span> vocab_itos<span class="token operator">=</span><span class="token boolean">None</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
        <span class="token builtin">super</span><span class="token punctuation">(</span>EncoderRNN<span class="token punctuation">,</span> self<span class="token punctuation">)</span><span class="token punctuation">.</span>__init__<span class="token punctuation">(</span><span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>hidden_size <span class="token operator">=</span> hidden_size
        <span class="token keyword">if</span> use_pretrained <span class="token operator">and</span> vocab_itos <span class="token keyword">is</span> <span class="token boolean">None</span><span class="token punctuation">:</span>
            <span class="token keyword">raise</span> ValueError<span class="token punctuation">(</span><span class="token string">'`use_pretrained=True` with `vocab_itos=None`, please provide the vocab itos List'</span><span class="token punctuation">)</span>
        <span class="token keyword">if</span> use_pretrained<span class="token punctuation">:</span>
            <span class="token comment"># GloVe 6B defaults to 300-dimensional vectors; from_pretrained freezes them by default</span>
            glove_vec <span class="token operator">=</span> torchtext<span class="token punctuation">.</span>vocab<span class="token punctuation">.</span>GloVe<span class="token punctuation">(</span>name<span class="token operator">=</span><span class="token string">'6B'</span><span class="token punctuation">)</span>
            glove_emb <span class="token operator">=</span> glove_vec<span class="token punctuation">.</span>get_vecs_by_tokens<span class="token punctuation">(</span>vocab_itos<span class="token punctuation">)</span>
            self<span class="token punctuation">.</span>embedding <span class="token operator">=</span> nn<span class="token punctuation">.</span>Embedding<span class="token punctuation">.</span>from_pretrained<span class="token punctuation">(</span>glove_emb<span class="token punctuation">,</span> padding_idx<span class="token operator">=</span>PAD_token<span class="token punctuation">)</span>
        <span class="token keyword">else</span><span class="token punctuation">:</span>
            self<span class="token punctuation">.</span>embedding <span class="token operator">=</span> nn<span class="token punctuation">.</span>Embedding<span class="token punctuation">(</span>input_size<span class="token punctuation">,</span> hidden_size<span class="token punctuation">)</span>
        <span class="token keyword">assert</span> self<span class="token punctuation">.</span>embedding<span class="token punctuation">.</span>embedding_dim <span class="token operator">==</span> hidden_size<span class="token punctuation">,</span>\
            f<span class="token string">'hidden_size must equal embedding dim, found hidden_size={hidden_size}, embedding_dim={self.embedding.embedding_dim}'</span>
        self<span class="token punctuation">.</span>gru <span class="token operator">=</span> nn<span class="token punctuation">.</span>GRU<span class="token punctuation">(</span>hidden_size<span class="token punctuation">,</span> hidden_size<span class="token punctuation">)</span>
</code></pre>
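<p>One thing worth checking when swapping in pretrained vectors is vocabulary coverage: with torchtext's default settings, a token that GloVe does not know comes back from <code>get_vecs_by_tokens</code> as an all-zero vector. The sketch below is a minimal, torch-free illustration. <code>glove_coverage</code> and the toy <code>known</code> set are hypothetical stand-ins for the real GloVe table (with torchtext, the keys of <code>glove_vec.stoi</code> would play that role).</p>

```python
def glove_coverage(vocab_itos, known_tokens):
    """Return (fraction of vocab found, list of tokens that are missing)."""
    missing = [tok for tok in vocab_itos if tok not in known_tokens]
    found = len(vocab_itos) - len(missing)
    fraction = found / len(vocab_itos) if vocab_itos else 0.0
    return fraction, missing


# Toy stand-in for the set of tokens GloVe has vectors for (assumption).
known = {"hi", "how", "are", "you", "?"}
vocab_itos = ["SOS", "EOS", "hi", "satyajit", "how", "are", "you", "?"]

fraction, missing = glove_coverage(vocab_itos, known)
print(f"covered: {fraction:.1%}, missing: {missing}")
```

<p>Special tokens and names are the usual gaps. Since <code>nn.Embedding.from_pretrained</code> freezes the weights by default, rows that start as zeros stay zeros unless <code>freeze=False</code> is passed.</p>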
<h3 id="teacher-forcing">Teacher Forcing</h3>
<blockquote>
<p>Consider the task of sequence prediction, so you want to predict the next element of a sequence <span class="katex--inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>e</mi><mi>t</mi></msub></mrow><annotation encoding="application/x-tex">e_t</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.58056em; vertical-align: -0.15em;"></span><span class="mord"><span class="mord mathnormal">e</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.280556em;"><span class="" style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span></span></span></span></span> given the previous elements of this sequence <span class="katex--inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>e</mi><mrow><mi>t</mi><mtext>−</mtext><mn>1</mn></mrow></msub><mo separator="true">,</mo><msub><mi>e</mi><mrow><mi>t</mi><mtext>−</mtext><mn>2</mn></mrow></msub><mo separator="true">,</mo><mo>…</mo><mo separator="true">,</mo><msub><mi>e</mi><mn>1</mn></msub><mo>=</mo><msub><mi>e</mi><mrow><mi>t</mi><mtext>−</mtext><mn>1</mn><mo>:</mo><mn>1</mn></mrow></msub></mrow><annotation encoding="application/x-tex">e_{t−1},e_{t−2},…,e_{1}=e_{t−1:1}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.638891em; vertical-align: -0.208331em;"></span><span class="mord"><span class="mord mathnormal">e</span><span 
class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.301108em;"><span class="" style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mord mtight">−1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.208331em;"><span class=""></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord"><span class="mord mathnormal">e</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.301108em;"><span class="" style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mord mtight">−2</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.208331em;"><span class=""></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="minner">…</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord"><span class="mord mathnormal">e</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.301108em;"><span class="" style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span 
class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.638891em; vertical-align: -0.208331em;"></span><span class="mord"><span class="mord mathnormal">e</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.301108em;"><span class="" style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mord mtight">−1</span><span class="mrel mtight">:</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.208331em;"><span class=""></span></span></span></span></span></span></span></span></span></span>. Teacher forcing is about forcing the predictions to be based on correct histories (i.e. the correct sequence of past elements) rather than predicted history (which may not be correct). 
To be more concrete, let <span class="katex--inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mover accent="true"><msub><mi>e</mi><mi>i</mi></msub><mo>^</mo></mover></mrow><annotation encoding="application/x-tex">\hat{e_i}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.84444em; vertical-align: -0.15em;"></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.69444em;"><span class="" style="top: -3em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord mathnormal">e</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.311664em;"><span class="" style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span></span><span class="" style="top: -3em;"><span class="pstrut" style="height: 3em;"></span><span class="accent-body" style="left: -0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span></span></span></span> denote the <span class="katex--inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>i</mi></mrow><annotation encoding="application/x-tex">i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.65952em; vertical-align: 0em;"></span><span class="mord mathnormal">i</span></span></span></span></span>th predicted element of the sequence and let <span class="katex--inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>e</mi><mi>i</mi></msub></mrow><annotation encoding="application/x-tex">e_i</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.58056em; vertical-align: -0.15em;"></span><span class="mord"><span class="mord mathnormal">e</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.311664em;"><span class="" style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span></span></span></span></span> be the corresponding ground-truth.
Then, if you use teacher forcing, to predict <span class="katex--inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>e</mi><mi>t</mi></msub></mrow><annotation encoding="application/x-tex">e_t</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.58056em; vertical-align: -0.15em;"></span><span class="mord"><span class="mord mathnormal">e</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.280556em;"><span class="" style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span></span></span></span></span>, rather than using <span class="katex--inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mover accent="true"><msub><mi>e</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn><mo>:</mo><mn>1</mn></mrow></msub><mo>^</mo></mover></mrow><annotation encoding="application/x-tex">\hat{e_{t-1:1}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.902771em; vertical-align: -0.208331em;"></span><span class="mord accent"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.69444em;"><span class="" style="top: -3em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord mathnormal">e</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.301108em;"><span class="" style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span><span class="mrel mtight">:</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.208331em;"><span class=""></span></span></span></span></span></span></span><span class="" style="top: -3em;"><span class="pstrut" style="height: 3em;"></span><span class="accent-body" style="left: -0.25em;"><span class="mord">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.208331em;"><span class=""></span></span></span></span></span></span></span></span></span>, you would use <span class="katex--inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>e</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn><mo>:</mo><mn>1</mn></mrow></msub></mrow><annotation encoding="application/x-tex">e_{t-1:1}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.638891em; vertical-align: -0.208331em;"></span><span class="mord"><span class="mord mathnormal">e</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.301108em;"><span class="" style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span><span class="mrel mtight">:</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.208331em;"><span class=""></span></span></span></span></span></span></span></span></span></span>. <a href="https://ai.stackexchange.com/questions/18006/what-is-teacher-forcing">ai.stackexchange</a></p>
</blockquote>
<p>Here’s another explanation:</p>
<blockquote>
<p><em>Teacher forcing is like a teacher correcting a student as the student gets trained on a new concept. As the right input is given by the teacher to the student during training, student will learn the new concept faster and efficiently.</em></p>
</blockquote>
<p>When training with teacher forcing, we decide at random whether to force. When we do, the decoder's next input is the actual target token from the previous time step instead of the decoder's own prediction from that step.</p>
<pre class=" language-python"><code class="prism language-python"><span class="token keyword">if</span> use_teacher_forcing<span class="token punctuation">:</span>
    <span class="token comment"># Teacher forcing: Feed the target as the next input</span>
    <span class="token keyword">for</span> di <span class="token keyword">in</span> <span class="token builtin">range</span><span class="token punctuation">(</span>target_length<span class="token punctuation">)</span><span class="token punctuation">:</span>
        decoder_output<span class="token punctuation">,</span> decoder_hidden<span class="token punctuation">,</span> decoder_attention <span class="token operator">=</span> self<span class="token punctuation">.</span>attn_decoder<span class="token punctuation">(</span>
            decoder_input<span class="token punctuation">,</span> decoder_hidden<span class="token punctuation">,</span> encoder_outputs<span class="token punctuation">)</span>
        loss <span class="token operator">+=</span> self<span class="token punctuation">.</span>criterion<span class="token punctuation">(</span>decoder_output<span class="token punctuation">,</span> target_tensor<span class="token punctuation">[</span>di<span class="token punctuation">]</span><span class="token punctuation">)</span>
        decoder_input <span class="token operator">=</span> target_tensor<span class="token punctuation">[</span>di<span class="token punctuation">]</span>  <span class="token comment"># Teacher forcing</span>
<span class="token keyword">else</span><span class="token punctuation">:</span>
    <span class="token comment"># Without teacher forcing: use its own predictions as the next input</span>
    <span class="token keyword">for</span> di <span class="token keyword">in</span> <span class="token builtin">range</span><span class="token punctuation">(</span>target_length<span class="token punctuation">)</span><span class="token punctuation">:</span>
        decoder_output<span class="token punctuation">,</span> decoder_hidden<span class="token punctuation">,</span> decoder_attention <span class="token operator">=</span> self<span class="token punctuation">.</span>attn_decoder<span class="token punctuation">(</span>
            decoder_input<span class="token punctuation">,</span> decoder_hidden<span class="token punctuation">,</span> encoder_outputs<span class="token punctuation">)</span>
        topv<span class="token punctuation">,</span> topi <span class="token operator">=</span> decoder_output<span class="token punctuation">.</span>topk<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">)</span>
        decoder_input <span class="token operator">=</span> topi<span class="token punctuation">.</span>squeeze<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>detach<span class="token punctuation">(</span><span class="token punctuation">)</span>  <span class="token comment"># detach from history as input</span>
        loss <span class="token operator">+=</span> self<span class="token punctuation">.</span>criterion<span class="token punctuation">(</span>decoder_output<span class="token punctuation">,</span> target_tensor<span class="token punctuation">[</span>di<span class="token punctuation">]</span><span class="token punctuation">)</span>
        <span class="token keyword">if</span> decoder_input<span class="token punctuation">.</span>item<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token operator">==</span> EOS_token<span class="token punctuation">:</span>
            <span class="token keyword">break</span>
</code></pre>
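<p>The random choice itself is not shown in the snippet above. A common pattern, sketched here under the assumption of a <code>teacher_forcing_ratio</code> hyperparameter (a hypothetical name, not taken from this code), is one coin flip per training pair:</p>

```python
import random


def choose_teacher_forcing(teacher_forcing_ratio, rng=random):
    """Flip a biased coin once per training pair."""
    # rng.random() is in [0, 1), so a ratio of 1.0 always forces
    # and a ratio of 0.0 never does.
    return rng.random() < teacher_forcing_ratio


def next_decoder_input(target_token, predicted_token, use_teacher_forcing):
    """Pick what the decoder is fed at the next time step."""
    return target_token if use_teacher_forcing else predicted_token


rng = random.Random(0)  # seeded only to make the sketch repeatable
forced = sum(choose_teacher_forcing(0.5, rng) for _ in range(1000))
print(f"teacher forcing used on {forced} of 1000 pairs")
```

<p>A ratio of 0.5 forces roughly half of the pairs. Values closer to 1.0 lean harder on the ground truth early in training.</p>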
<p>But why do we really have to do this?</p>
<p>Let's assume we have a slightly trained network for the encoder and decoder.</p>
<p>And these are our sentences:</p>
<p><code>SRC: <SOS> hi satyajit how are you ? <EOS></code><br>
<code>TGT: <SOS> salut satyajit comment vas-tu ? <EOS></code></p>
<p>After the entire <code>SRC</code> is sent to the encoder word by word, we will have some encoder outputs, which would be close to <em>meaningless</em> since the model is not trained that well.</p>
<p>This is what the decoder will see when it is fed its own predictions:</p>
<pre><code>INPUT            PRED
[SOS]            a
[SOS] a          a ??
[SOS] a ??       a ?? ??
</code></pre>
<p>See how difficult it is for the decoder RNN to continue from its own meaningless output. This makes training unstable and the model very slow to learn, which is why we randomly feed the target sentence itself to the decoder during training:</p>
<pre><code>INPUT TEACHER FORCED PRED
[SOS] ??
[SOS] ?? ??
[SOS] ?? satyajit ?? satyajit
[SOS] ?? satyajit how ?? satyajit comment
</code></pre>
<p>As sketched above, the decoder is fed the actual target words as its previous inputs, so every step is conditioned on a correct history and it learns better.</p>
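<p>The effect can be reproduced with a toy stand-in for the decoder. In this sketch the "model" is just a hypothetical lookup table from the previous token to the next one, and it has learned every step except the first. None of this is the network used in this session.</p>

```python
SOS, EOS, UNK = "SOS", "EOS", "??"
target = ["salut", "satyajit", "comment", "vas-tu", "?", EOS]

# Toy "decoder": maps the previous token to a predicted next token.
# It gets the first step wrong (SOS -> "a") but knows the rest.
toy_decoder = {
    SOS: "a",
    "salut": "satyajit",
    "satyajit": "comment",
    "comment": "vas-tu",
    "vas-tu": "?",
    "?": EOS,
}


def decode(target, use_teacher_forcing):
    """Run the toy decoder and return its prediction at every step."""
    predictions = []
    decoder_input = SOS
    for truth in target:
        predicted = toy_decoder.get(decoder_input, UNK)
        predictions.append(predicted)
        # Teacher forcing feeds the ground truth; otherwise the
        # decoder consumes its own (possibly wrong) prediction.
        decoder_input = truth if use_teacher_forcing else predicted
    return predictions


print("free running  :", decode(target, use_teacher_forcing=False))
print("teacher forced:", decode(target, use_teacher_forcing=True))
```

<p>Without forcing, the single early mistake derails every later step. With forcing, only the first step is wrong, so the remaining steps still give a useful training signal.</p>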
<h3 id="further-possible-improvement">Further possible improvement</h3>
<ul>
<li>The model does not support batching. Adding it would greatly improve training speed and would likely give a lower, less noisy loss as well.</li>
<li>The optimizer used here is SGD, which tends to converge slowly on models like this one; an adaptive optimizer such as Adam would likely have been a better choice.</li>
<li>I found <a href="https://github.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/10_Seq2Seq_Attention/seq2seq-translation/seq2seq-translation-batched.ipynb">this</a> really good notebook that shows different kinds of attention models, and guess what! It is batched!</li>
</ul>
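<p>Batching variable-length sentences usually starts with padding them to a common length and keeping a mask of the real tokens. A minimal sketch in plain Python, where <code>pad_batch</code> is a hypothetical helper and the value of <code>PAD_token</code> is assumed to be 0:</p>

```python
PAD_token = 0  # assumed index of the padding token


def pad_batch(sequences, pad_token=PAD_token):
    """Pad token-index lists to equal length; return (padded, mask)."""
    max_len = max(len(seq) for seq in sequences)
    padded = [seq + [pad_token] * (max_len - len(seq)) for seq in sequences]
    # mask is 1 for real tokens and 0 for padding, so padded positions
    # can be excluded from the loss.
    mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences]
    return padded, mask


batch = [[5, 8, 2], [7, 2], [4, 9, 6, 2]]
padded, mask = pad_batch(batch)
print(padded)
print(mask)
```

<p>The padded lists can then be turned into a single tensor, and the mask keeps padded positions from contributing to the loss.</p>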
<h3 id="sample-output">Sample Output</h3>
<pre><code>[KEY: > input, = target, < output]
> he s not going .
= il ne s y rend pas .
< il ne s y y . <EOS>
> we re not happy .
= nous ne sommes pas heureuses .
< nous ne sommes pas heureux . <EOS>
> we re too old .
= nous sommes trop vieux .
< nous sommes trop vieux . <EOS>
> i m not a crook .
= je ne suis pas un escroc .
< je ne suis pas un . <EOS>
> you re free of all responsibility .
= vous etes liberee de toute responsabilite .
< vous etes liberee de toute responsabilite . <EOS>
> i m sorry we re completely sold out .
= je suis desole nous avons ete devalises .
< je suis desole nous avons tout vendu . <EOS>
> you are the one .
= vous etes l elu .
< vous etes celui la . <EOS>
> they re all dead .
= elles sont toutes mortes .
< ils sont tous des . <EOS>
> he s always late for school .
= il est toujours en retard a l ecole .
< il est toujours en retard a l ecole . <EOS>
> he is busy .
= il a a faire .
< il a l l l l <EOS>
</code></pre>
<hr>
<h3 id="some-attention-visualizations">Some Attention Visualizations</h3>
<pre><code>input = i m very impressed by your work .
output = je suis tres par par votre travail . <EOS>
</code></pre>
<p><img src="https://github.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/10_Seq2Seq_Attention/attentions/attn1.png?raw=true" alt="attn1"></p>
<hr>
<pre><code>input = we re smart .
output = nous sommes intelligents . <EOS>
</code></pre>
<p><img src="https://github.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/10_Seq2Seq_Attention/attentions/attn2.png?raw=true" alt="attn2"></p>
<hr>
<pre><code>input = i m still hungry .
output = j ai toujours faim . <EOS>
</code></pre>
<p><img src="https://github.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/10_Seq2Seq_Attention/attentions/attn3.png?raw=true" alt="attn3"></p>
<hr>
<pre><code>input = he is very eager to go there .
output = il est tres sensible de partir . <EOS>
</code></pre>
<p><img src="https://github.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/10_Seq2Seq_Attention/attentions/attn4.png?raw=true" alt="attn4"></p>
<hr>
<pre><code>input = i m sorry we re completely sold out .
output = je suis desole nous avons tout vendu . <EOS>
</code></pre>
<p><img src="https://github.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/10_Seq2Seq_Attention/attentions/attn5.png?raw=true" alt="attn5"></p>
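<p>A heatmap like the ones above can also be read programmatically: for each output word, take the input word with the largest attention weight. The weights in this sketch are made-up illustrative numbers, not values from the trained model, and <code>strongest_alignment</code> is a hypothetical helper.</p>

```python
def strongest_alignment(input_words, output_words, attention):
    """For each output word, return the input word it attends to most."""
    pairs = []
    for out_word, row in zip(output_words, attention):
        best = max(range(len(row)), key=row.__getitem__)
        pairs.append((out_word, input_words[best]))
    return pairs


input_words = ["we", "re", "smart", "."]
output_words = ["nous", "sommes", "intelligents", "."]
# One row per output word, one column per input word (rows sum to 1).
attention = [
    [0.70, 0.20, 0.05, 0.05],
    [0.15, 0.65, 0.15, 0.05],
    [0.05, 0.10, 0.80, 0.05],
    [0.05, 0.05, 0.10, 0.80],
]

for out_word, in_word in strongest_alignment(input_words, output_words, attention):
    print(f"{out_word} -> {in_word}")
```

<p>A roughly diagonal pattern like this one is what a well-aligned short sentence tends to look like. Real attention maps are usually more diffuse.</p>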
<hr>
<p align="center">
Thanks for reading, have a great day 😄
</p>
<hr>
<p align="center">
:wq satyajit
</p>
</div>
</body>
</html>