29 | 29 |
30 | 30 | ######################################################################
31 | 31 | # In this tutorial, we train a ``nn.TransformerEncoder`` model on a
32 |    | -# language modeling task. Please note that this tutorial does not cover
   | 32 | +# causal language modeling task. Please note that this tutorial does not cover
33 | 33 | # the training of `nn.TransformerDecoder <https://pytorch.org/docs/stable/generated/torch.nn.TransformerDecoder.html#torch.nn.TransformerDecoder>`__, as depicted in
34 | 34 | # the right half of the diagram above. The language modeling task is to assign a
35 | 35 | # probability for the likelihood of a given word (or a sequence of words)

41 | 41 | # Along with the input sequence, a square attention mask is required because the
42 | 42 | # self-attention layers in ``nn.TransformerDecoder`` are only allowed to attend
43 | 43 | # the earlier positions in the sequence. For the language modeling task, any
44 |    | -# tokens on the future positions should be masked. To produce a probability
45 |    | -# distribution over output words, the output of the ``nn.TransformerEncoder``
   | 44 | +# tokens on the future positions should be masked. This masking, combined with the fact that
   | 45 | +# the output embeddings are offset by one position, ensures that the
   | 46 | +# predictions for position i can depend only on the known outputs at positions less than i.
   | 47 | +# To produce a probability distribution over output words, the output of the ``nn.TransformerEncoder``
46 | 48 | # model is passed through a linear layer to output unnormalized logits.
47 | 49 | # The log-softmax function isn't applied here due to the later use of
48 | 50 | # `CrossEntropyLoss <https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html>`__,
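The causal mask described above can be inspected directly. The sketch below is illustrative and not part of the commit; it assumes only that ``torch`` is installed and uses the same ``nn.Transformer.generate_square_subsequent_mask`` helper that the updated ``forward`` method calls:

import torch
import torch.nn as nn

# Positions a token may attend to hold 0.0; future positions hold -inf,
# so they receive zero weight after the softmax inside the attention layers.
mask = nn.Transformer.generate_square_subsequent_mask(4)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])

As noted above, the final linear layer returns raw logits because ``nn.CrossEntropyLoss`` applies log-softmax internally.
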
@@ -91,6 +93,11 @@ def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
91 |  93 |         """
92 |  94 |         src = self.embedding(src) * math.sqrt(self.d_model)
93 |  95 |         src = self.pos_encoder(src)
   |  96 | +        if src_mask is None:
   |  97 | +            # Generate a square causal mask for the sequence. The masked
   |  98 | +            # positions are filled with float('-inf'); unmasked positions
   |  99 | +            # are filled with float(0.0).
   | 100 | +            src_mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(device)
94 | 101 |         output = self.transformer_encoder(src, src_mask)
95 | 102 |         output = self.linear(output)
96 | 103 |         return output
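
A hypothetical usage sketch of the updated ``forward``: with the new ``src_mask=None`` default, callers may omit the mask and it is generated inside the model. ``TransformerModel``, its hyperparameters, and ``device`` are assumed from the rest of the tutorial and are illustrative here, not part of the commit:

import torch

ntokens = 1000                                            # illustrative vocabulary size
model = TransformerModel(ntokens, d_model=200, nhead=2,   # assumed constructor signature
                         d_hid=200, nlayers=2, dropout=0.2).to(device)

src = torch.randint(0, ntokens, (35, 8), device=device)   # [seq_len, batch_size]

logits = model(src)                        # src_mask is None -> causal mask built internally
explicit = torch.nn.Transformer.generate_square_subsequent_mask(35).to(device)
logits_explicit = model(src, explicit)     # equivalent explicit-mask call
print(logits.shape)                        # torch.Size([35, 8, 1000]): per-position vocabulary logits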