 # LICENSE file in the root directory of this source tree.

 from functools import partial
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Tuple, Union

 import torch
 from torch import nn, Tensor
+from torchmultimodal.modules.layers.attention import MultiHeadAttention, SelfAttention
+from torchmultimodal.modules.layers.mlp import MLP
+from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm

 from torchmultimodal.modules.layers.transformer import TransformerOutput

@@ -75,6 +78,223 @@ def forward(
         )


+class TransformerEncoderLayer(nn.Module):
+    """A transformer encoder layer made up of a multihead self-attention block and a feedforward block,
+    based on the architecture in "Attention Is All You Need" (Vaswani et al. 2017). Similar to
+    ``nn.TransformerEncoderLayer``, but uses a custom ``MultiHeadAttention`` that supports
+    n-dimensional inputs (including sequences, images, and video) and head masking.
+
+    Attributes:
+        d_model (int): size of the hidden dimension of the input
+        n_head (int): number of attention heads
+        dim_feedforward (int): size of the hidden dimension of the feedforward network
+        dropout (float): dropout probability for all dropouts. Defaults to 0.
+        activation (Callable): activation function in the feedforward network. Defaults to ``nn.ReLU``.
+        layer_norm_eps (float): the eps value in layer norms. Defaults to 1e-12.
+        norm_first (bool): if True, layer norm is applied before each of self-attention and feedforward.
+            Otherwise, layer norm is applied after each. Defaults to False.
+
+    Args:
+        hidden_states (Tensor): input tensor of shape [b, d1, ..., dn, c] to calculate self-attention on.
+        attention_mask (Tensor, optional): mask to be applied to the self-attention scores. See
+            ``MultiHeadAttention`` for shape requirements.
+        head_mask (Tensor, optional): mask to be applied to the attention probabilities after softmax and
+            dropout, before matrix multiplication with the values. See ``MultiHeadAttention`` for shape
+            requirements.
+        return_attn_weights (bool, optional): if True, return attention probabilities in addition to the
+            attention output. Defaults to False.
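+
+    Example (a minimal sketch; the hidden size and sequence shape below are arbitrary)::
+
+        >>> layer = TransformerEncoderLayer(d_model=512, n_head=8, dim_feedforward=2048)
+        >>> x = torch.randn(2, 16, 512)  # batch of two 16-token sequences with hidden dim 512
+        >>> out = layer(x)
+        >>> out.shape
+        torch.Size([2, 16, 512])
+        >>> _, weights = layer(x, return_attn_weights=True)  # also return per-head attention probabilities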
| 105 | + """ |
| 106 | + |
+    def __init__(
+        self,
+        d_model: int,
+        n_head: int,
+        dim_feedforward: int,
+        dropout: float = 0.0,
+        activation: Callable[..., nn.Module] = nn.ReLU,
+        layer_norm_eps: float = 1e-12,
+        norm_first: bool = False,
+    ) -> None:
+        super().__init__()
+        # attention block
+        self.attention = MultiHeadAttention(
+            dim_q=d_model,
+            dim_kv=d_model,
+            n_head=n_head,
+            attn_module=SelfAttention(dropout),
+        )
+        self.attention_dropout = nn.Dropout(dropout)
+        # feedforward block
+        self.feedforward = MLP(
+            d_model, d_model, dim_feedforward, dropout=dropout, activation=activation
+        )
+        self.feedforward_dropout = nn.Dropout(dropout)
+        # layernorms
+        self.attention_layernorm = Fp32LayerNorm(d_model, eps=layer_norm_eps)
+        self.feedforward_layernorm = Fp32LayerNorm(d_model, eps=layer_norm_eps)
+        self.norm_first = norm_first
+
+    def _attention_block(
+        self,
+        hidden_states: Tensor,
+        attention_mask: Optional[Tensor] = None,
+        head_mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        output, attn_weights = self.attention(
+            hidden_states,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            return_attn_weights=True,
+        )
+        output = self.attention_dropout(output)
+        return output, attn_weights
+
+    def _feedforward_block(self, hidden_states: Tensor) -> Tensor:
+        h = self.feedforward(hidden_states)
+        h = self.feedforward_dropout(h)
+        return h
+
+    def _forward_prenorm(
+        self,
+        hidden_states: Tensor,
+        attention_mask: Optional[Tensor] = None,
+        head_mask: Optional[Tensor] = None,
+        return_attn_weights: bool = False,
+    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
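+        # Pre-norm: layer norm is applied to the input of each sublayer, and the
+        # residual connections carry the un-normalized stream.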
+        x = hidden_states
+        inputs = self.attention_layernorm(x)
+        attn_output, attn_weights = self._attention_block(
+            inputs,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+        )
+        attn_residual = attn_output + x
+        ff_residual = attn_residual + self._feedforward_block(
+            self.feedforward_layernorm(attn_residual)
+        )
+        if return_attn_weights:
+            return ff_residual, attn_weights
+        else:
+            return ff_residual
+
+    def _forward_postnorm(
+        self,
+        hidden_states: Tensor,
+        attention_mask: Optional[Tensor] = None,
+        head_mask: Optional[Tensor] = None,
+        return_attn_weights: bool = False,
+    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
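+        # Post-norm: layer norm is applied to the output of each residual connection.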
+        x = hidden_states
+        attn_output, attn_weights = self._attention_block(
+            x,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+        )
+        attn_residual = attn_output + x
+        attn_residual = self.attention_layernorm(attn_residual)
+        ff_residual = attn_residual + self._feedforward_block(attn_residual)
+        outputs = self.feedforward_layernorm(ff_residual)
+        if return_attn_weights:
+            return outputs, attn_weights
+        else:
+            return outputs
+
+    def forward(
+        self,
+        hidden_states: Tensor,
+        attention_mask: Optional[Tensor] = None,
+        head_mask: Optional[Tensor] = None,
+        return_attn_weights: bool = False,
+    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+        if self.norm_first:
+            return self._forward_prenorm(
+                hidden_states,
+                attention_mask,
+                head_mask,
+                return_attn_weights,
+            )
+        else:
+            return self._forward_postnorm(
+                hidden_states,
+                attention_mask,
+                head_mask,
+                return_attn_weights,
+            )
+
+
+class TransformerEncoder(nn.Module):
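+    """A stack of ``n_layer`` ``TransformerEncoderLayer`` modules with an optional final layer norm.
+
+    Attributes:
+        n_layer (int): number of encoder layers in the stack
+        d_model (int): size of the hidden dimension of the input
+        n_head (int): number of attention heads
+        dim_feedforward (int): size of the hidden dimension of the feedforward network
+        dropout (float): dropout probability for all dropouts. Defaults to 0.
+        activation (Callable): activation function in the feedforward network. Defaults to ``nn.ReLU``.
+        layer_norm_eps (float): the eps value in layer norms. Defaults to 1e-12.
+        norm_first (bool): if True, layer norm is applied before the attention and feedforward blocks
+            in each layer. Otherwise, layer norm is applied after. Defaults to False.
+        final_layer_norm_eps (Optional[float]): if not None, apply a final ``Fp32LayerNorm`` with this
+            eps to the last hidden state. Defaults to None.
+
+    Args:
+        hidden_states (Tensor): input tensor of shape [b, d1, ..., dn, c].
+        attention_mask (Tensor, optional): mask passed to each layer's self-attention. Defaults to None.
+        head_mask (Tensor, optional): head mask passed to each layer's self-attention. Defaults to None.
+        return_attn_weights (bool): if True, also return the attention probabilities of every layer.
+            Defaults to False.
+        return_hidden_states (bool): if True, also return the hidden states seen by every layer.
+            Defaults to False.
+
+    Example (a minimal sketch; the sizes below are arbitrary)::
+
+        >>> encoder = TransformerEncoder(n_layer=2, d_model=512, n_head=8, dim_feedforward=2048)
+        >>> out = encoder(torch.randn(2, 16, 512), return_hidden_states=True)
+        >>> out.last_hidden_state.shape
+        torch.Size([2, 16, 512])
+        >>> len(out.hidden_states)  # input to each layer plus the final output
+        3
+    """
+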
+    def __init__(
+        self,
+        n_layer: int,
+        d_model: int,
+        n_head: int,
+        dim_feedforward: int,
+        dropout: float = 0.0,
+        activation: Callable[..., nn.Module] = nn.ReLU,
+        layer_norm_eps: float = 1e-12,
+        norm_first: bool = False,
+        final_layer_norm_eps: Optional[float] = None,
+    ) -> None:
+        super().__init__()
+        self.layer = nn.ModuleList(
+            [
+                TransformerEncoderLayer(
+                    d_model,
+                    n_head,
+                    dim_feedforward,
+                    dropout,
+                    activation,
+                    layer_norm_eps,
+                    norm_first,
+                )
+                for _ in range(n_layer)
+            ]
+        )
+        self.final_layer_norm = None
+        if final_layer_norm_eps:
+            self.final_layer_norm = Fp32LayerNorm(d_model, eps=final_layer_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: Tensor,
+        attention_mask: Optional[Tensor] = None,
+        head_mask: Optional[Tensor] = None,
+        return_attn_weights: bool = False,
+        return_hidden_states: bool = False,
+    ) -> TransformerOutput:
+
+        all_hidden_states = [] if return_hidden_states else None
+        all_self_attentions = [] if return_attn_weights else None
+
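+        # Run each layer in sequence, optionally recording the input to every layer
+        # and the attention weights it produces.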
+        for layer_module in self.layer:
+            if return_hidden_states:
+                all_hidden_states.append(hidden_states)
+
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                return_attn_weights=return_attn_weights,
+            )
+
+            if return_attn_weights:
+                hidden_states = layer_outputs[0]
+                all_self_attentions.append(layer_outputs[1])
+            else:
+                hidden_states = layer_outputs
+
+        if return_hidden_states:
+            all_hidden_states.append(hidden_states)
+
+        if self.final_layer_norm is not None:
+            hidden_states = self.final_layer_norm(hidden_states)
+
+        return TransformerOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+
 def init_transformer_weights(module: nn.Module, initializer_range: float) -> None:
     """Initialize the weights"""
     if isinstance(module, (nn.Linear, nn.Conv2d)):