The convolutional neural network has been the mainstay of deep learning vision approaches for decades, dating back to the work of LeCun and colleagues in 1989. In that work, it was proposed that restrictions on model capacity would be necessary to prevent over-parametrized fully connected models from failing to generalize, and that these restrictions (translation invariance, local weight sharing, etc.) could be encoded into the model itself.
Convolution-based neural networks have since become the predominant deep learning method for image classification and generation due to their parsimonious weight sharing (allowing larger inputs to be modeled with fewer parameters than traditional fully connected models), their flexibility (with proper pooling after convolutional layers, a single model may be applied to images of many sizes), and above all their efficacy (nearly every state-of-the-art vision model since the early 90s has been based on convolutions).
It is interesting to note, therefore, that one of the primary motivations for the use of convolutions, that over-parametrized models must be restricted in order to avoid overfitting, has since been found not to apply to deep learning models. Over-parametrized fully connected models do not tend to overfit image data even though they are capable of doing so, and furthermore convolutional models that currently classify large image datasets quite accurately are capable of fitting pure noise (ref).
Therefore it is reasonable to hypothesize that the convolutional architecture, although effective and flexible, is by no means required for accurate image classification or other vision tasks. One particularly effective approach, termed the 'transformer', has been translated from the field of natural language processing and makes use of self-attention mechanisms. We also consider MLP-based mixer architectures that do not make use of attention.
We focus on the ViT B 32 model introduced by Dosovitskiy and colleagues. This model is based on the original transformer from Vaswani and colleagues, in which self-attention modules previously applied with recurrent neural networks were instead applied to patched and positionally encoded sequences in series with simple fully connected architectures.
The transformer architecture was found to be effective for natural language processing tasks and was subsequently employed in vision tasks after convolutional layers. But the work introducing the ViT went further and applied the transformer architecture directly to patches of images, which has been claimed to occur without explicit convolutions (although, as we will see later, the model does in fact use a strided convolution to form patch embeddings). It is an open question how similar these models are to convolutional neural networks.
The transformer is a feedforward neural network model that adapted a concept called 'self-attention' from earlier recurrent neural networks. Attention modules attempted to overcome the tendency of recurrent neural networks to be overly influenced by input elements that directly precede the current element and to 'forget' those that came long before. The original transformer innovated by applying attention to tokens (usually word embeddings) followed by an MLP, foregoing the time inefficiencies associated with recurrent neural nets.
In the original self-attention module, each input (usually an embedding of a word) is associated with three vectors: a query $q$, a key $k$, and a value $v$, each formed by multiplying the input embedding $x$ by a learned weight matrix,

$$q = W^Q x, \quad k = W^K x, \quad v = W^V x$$

Collecting these vectors for all tokens into matrices $Q$, $K$, and $V$, the dot products between queries and keys are computed before constant scaling followed by a softmax transformation to the vector of attention weights, which are then used to weight the values:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
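As a concrete illustration of this computation, the following is a minimal sketch of single-head scaled dot-product attention; the weight matrices w_q, w_k, and w_v are random stand-ins for learned parameters, and the dimensions are chosen only for the example:

import torch

def single_head_attention(x, w_q, w_k, w_v):
    # x: (sequence_length, embedding_dim) token embeddings
    q = x @ w_q  # queries
    k = x @ w_k  # keys
    v = x @ w_v  # values
    d_k = k.shape[-1]
    # dot product between every query and every key, scaled by sqrt(d_k)
    scores = (q @ k.T) / d_k**0.5
    # softmax over the key dimension yields each token's attention weights
    weights = torch.softmax(scores, dim=-1)
    # each output is a weighted sum of the value vectors
    return weights @ v

x = torch.randn(50, 768)  # e.g. 49 patch tokens plus one class token
w_q, w_k, w_v = (torch.randn(768, 64) / 768**0.5 for _ in range(3))
print(single_head_attention(x, w_q, w_k, w_v).shape)  # torch.Size([50, 64])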
The theoretical basis behind the attention module is that certain tokens (originally word embeddings) should 'pay attention' to certain other tokens more so than average, and that this relationship should be learned directly by the model. For example, given the sentence 'The dog felt animosity towards the cat, so he behaved poorly towards it', it is clear that the word 'it' should be closely associated with the word 'cat', and the attention module's goal is to model such associations.
When we reflect on the separate mathematical operations of attention, it is clear that they do indeed capture something that may be accurately described by the English word. In the first step of attention, the production of query, key, and value vectors projects each token into spaces in which the dot product $q \cdot k$ measures how relevant one token is to another, and the softmax-weighted sum of value vectors then draws most of each output from the tokens judged most relevant.
But this clean theoretical justification breaks down when one considers that models with single attention modules generally do not perform well on their own, but instead require many attention modules in parallel (termed multi-head attention) and in series. Given multi-head attention, one might consider each separate attention value to be context-specific, but it is then unclear why attention should be used at all, given that an MLP alone may be thought of as providing context-specific attention. Transformer-based models are furthermore typically many layers deep, and it is unclear what the attention value of an attention value of a token actually means.
Nevertheless, to gain familiarity with this model we note that for multi-head attention, multiple self-attention operations are performed in parallel on separate learned projections of the input, and their outputs are concatenated and passed through a final linear projection back to the model's hidden dimension.
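In PyTorch this parallel computation is available directly through torch.nn.MultiheadAttention, the module torchvision's ViT uses for its self_attention layers; a brief shape check (with dimensions matching ViT B 32's 768-dimensional hidden layer and 12 heads, and an assumed batch of one image's tokens) is sketched below:

import torch
from torch import nn

# ViT B 32 splits its 768-dimensional hidden layer across 12 heads (64 dimensions per head)
attention = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)

tokens = torch.randn(1, 50, 768)  # a batch of one: 49 patch tokens plus one class token
out, _ = attention(tokens, tokens, tokens, need_weights=False)  # self-attention: query = key = value
print(out.shape)  # torch.Size([1, 50, 768])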
A single transformer encoder applied to image data may be depicted as follows:
For a more thorough introduction to the transformer, see Alammar's Illustrated Transformer. See here for an example of a transformer encoder architecture applied to a character sequence classification task.
One way to understand how a model yields its outputs is to observe the inputs that one can generate using only the information present in the model itself. This may be accomplished by picking an output class of interest (here the models have been trained on ImageNet, so we can choose any of the 1000 classes present there), assigning a large constant to the output of that class, and then performing gradient descent on an initially random input, minimizing the difference between the model's output at the specified index and that large constant.
We add two modifications to the technique denoted in the last paragraph: a 3x3 Gaussian convolution applied to the input after each gradient update (starting from the initial random input) in order to suppress high-frequency noise, and positional jitter between updates.
More precisely, the image generation process is as follows: given a trained model $\theta$ with output $O(a, \theta)$ for an input $a$, we begin with an initial input $a_0$ of scaled random noise. Next we find the gradient of the absolute value of the difference between some large constant $C$ and the model's output at the chosen class index $i$, that is $\nabla_{a_n} \lvert C - O(a_n, \theta)_i \rvert$, and finally the input is updated by gradient descent followed by Gaussian convolution $\mathcal{N}$,

$$a_{n+1} = \mathcal{N}\left( a_n - \epsilon \, \nabla_{a_n} \lvert C - O(a_n, \theta)_i \rvert \right)$$

where $\epsilon$ is the update rate. Positional jitter is applied between updates such that the subset of the input seen by the model is shifted by a small random offset at each step.
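A minimal sketch of this generation procedure is given below, using torchvision's gaussian_blur for the 3x3 convolution and a random roll of the input as a stand-in for positional jitter; the class index, constant, learning rate, and iteration count are illustrative assumptions rather than the exact values used for the figures here:

import torch
import torchvision
import torchvision.transforms.functional as TF

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torchvision.models.vit_b_32(weights='IMAGENET1K_V1').to(device).eval()

target_class, large_constant = 207, 200.  # 207 is 'golden retriever'; the constant is an assumed value
epsilon = 0.1                             # input update rate (an assumed value)
a = torch.rand(1, 3, 224, 224, device=device) / 5 + 0.4  # begin with scaled random noise

for _ in range(220):
    a = a.detach().requires_grad_(True)
    # difference between the chosen class output and the large constant
    loss = torch.abs(large_constant - model(a)[0, target_class])
    loss.backward()
    with torch.no_grad():
        a = a - epsilon * a.grad                # gradient descent on the input
        a = TF.gaussian_blur(a, kernel_size=3)  # 3x3 Gaussian convolution
        # positional jitter: shift the input by a small random offset between updates
        dx, dy = torch.randint(-4, 5, (2,)).tolist()
        a = torch.roll(a, shifts=(dx, dy), dims=(2, 3))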
One of the first differences of note compared to the inputs generated from convolutional models is the lower resolution of the generated images: this is partly due to the inability of the vision transformer base 32 (ViT B 32) to pool outputs before the classification step, such that all model inputs must be of dimension 3x224x224.
When we observe representative images of a subset of ImageNet animal classes with ViT B 32,
as well as landscapes and inanimate objects with the same model,
it is clear that recognizable images may be formed using only the information present in the vision transformer architecture just as was accomplished for convolutional models.
Vision Transformer hidden layer representation overview
Another way to understand a model is to observe the extent to which various hidden layers in that model are able to autoencode an input: the better the autoencoding, the more complete the information in that layer.
First, let's examine an image that is representative of one of the 1000 categories in the ImageNet 1k dataset: a Dalmatian. We can obtain various layer outputs by subclassing the PyTorch nn.Module
class and accessing the original vision transformer model as self.model
. For ViT models, the desired transformer encoder layer outputs may be accessed as follows:
import torch
import torchvision
from torch import nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class NewVit(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor, layer: int):
        # Reshape and permute the input tensor into a sequence of patch embeddings
        x = self.model._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch
        batch_class_token = self.model.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        # Apply the first `layer` transformer encoder blocks
        for i in range(layer):
            x = self.model.encoder.layers[i](x)
        return x
where layer
indicates the layer of the output desired (self.model.encoder.layers
are 0-indexed so the numbering is accurate). The NewVit
class can then be instantiated as
vision_transformer = torchvision.models.vit_b_32(weights='IMAGENET1K_V1')
new_vision = NewVit(vision_transformer).to(device)
new_vision.eval()
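With the subclassed model instantiated, a given layer's representation of a target image may be generated by the same sort of gradient descent on an initially random input, this time minimizing the distance between the layer's output for the generated input and its output for the target image. The sketch below assumes target_image is a normalized 1x3x224x224 tensor on the same device, and the learning rate and iteration count are placeholders:

import torch
import torchvision.transforms.functional as TF

def layer_representation(target_image, layer, iterations=500, epsilon=0.01):
    with torch.no_grad():
        target_output = new_vision(target_image, layer)  # the layer output to be matched
    a = torch.rand_like(target_image) / 5 + 0.4          # begin with scaled random noise
    for _ in range(iterations):
        a = a.detach().requires_grad_(True)
        # distance between the layer's outputs for the generated and target inputs
        loss = torch.sum(torch.abs(new_vision(a, layer) - target_output))
        loss.backward()
        with torch.no_grad():
            a = a - epsilon * a.grad                # update the generated input
            a = TF.gaussian_blur(a, kernel_size=3)  # smooth high-frequency noise
    return a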
First let's observe the effect of layer depth (specifically transformer encoder layer depth) on the representation accuracy of an untrained ViT_B_32 and compare this to what was observed for ResNet50 (which appears to be a fairly good stand-in for other convolutional models).
The first notable aspect of input representation is that it appears to be much more difficult to approximate a natural image using Gaussian-convolved representations for ViT than for ResNet50; more precisely, it is difficult to find a learning rate (and Gaussian convolution schedule) that suppresses high-frequency noise in the generated input without also removing the detail needed for an accurate representation.
Compare the decreasing representation clarity with increased depth in ResNet50 to the nearly constant degree of clarity in the ViT: even at the twelfth and final encoder, the representation quality is approximately the same as that in the first layer. The reason why this is the case is explored in the next section.
As each encoder is the same size, we can also observe the representation of only a given encoder (rather than that encoder together with all those preceding it).
When we consider the representation of the first layer of the vision transformer compared to the first layer in ResNet50, it is apparent that the former has a less accurate representation. Before this first layer, ViT has an input processing step in which the input is encoded as a sequence of tokens: 768 convolutional filters, each of size 3x32x32 (hence the name ViT B 32), are applied with stride 32. In effect, this means that the model takes 32x32 patches (3 color channels each) of the original input and encodes each as a 768-dimensional embedding, yielding a 768x7x7 output whose patch embeddings act analogously to the word embeddings used in the original transformer.
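This patch embedding is implemented in torchvision as an ordinary strided convolution, which can be verified directly:

import torchvision

vit = torchvision.models.vit_b_32()
print(vit.conv_proj)
# Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
# a 3x224x224 input therefore becomes a 768x7x7 tensor: 49 patch embeddings of 768 values each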
It may be wondered, then, whether it is the input processing step (the strided 32x32 convolution) or the first encoder layer that is responsible for the decrease in representation accuracy. Generating representations of the output of the input processing convolution alone, and then of the input processing followed by the first encoder layer, for an initial image of a Tesla coil, it is clear that the input processing itself is responsible for the decreased representation clarity, and furthermore that training greatly enhances the processing convolution's representation resolution (although still not to the degree seen in the first convolution of ResNet).
For ResNet50, an increase in representation resolution for the first convolutional layer upon training is observed to coincide with the appearance of Gabor function wavelets in the weights of that layer (see the last supplementary figure of this paper). It may be wondered if the same effect of training is observed for these strided convolutions, and so we plot the normalized (minimum set to 0, maximum set to 1, and all other weights assigned accordingly) weights before and after training to find out.
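The normalization and plotting may be performed along the following lines (a sketch using matplotlib; the 8x8 grid showing the first 64 filters is an arbitrary choice):

import matplotlib.pyplot as plt
import torchvision

trained = torchvision.models.vit_b_32(weights='IMAGENET1K_V1')
weights = trained.conv_proj.weight.detach()  # shape: (768, 3, 32, 32)

fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for ax, w in zip(axes.flat, weights):
    w = (w - w.min()) / (w.max() - w.min())  # normalize each filter to [0, 1]
    ax.imshow(w.permute(1, 2, 0).numpy())    # channels last for imshow
    ax.axis('off')
plt.show()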
In some convolutions we do indeed see wavelets (of various frequencies, too), but in others we see something curious: no discernible pattern at all is visible in the weights of around half of the input convolutional filters. As seen in the paper referenced in the last paragraph, this is not at all what is found for ResNet50's first convolutional layer, where every convolutional filter plotted has a markedly non-random weight distribution (most are wavelets).
Earlier it was observed that for ViT B 32 the process of training led to the appearance of wavelet patterns in the input convolution layer and a concomitant increase in representational accuracy. For that model the convolutional operation is not overcomplete, but for the ViT Large 16 model it is. It can therefore be hypothesized that training is not necessary for accurate input representation by the processing convolution of ViT L 16, and indeed this is found to be the case.
Note the lack of consistent wavelet weight patterns in the input convolution, even after training (and even after extensive pretraining on weakly supervised data). This observation may explain why Xiao and colleagues found that replacing the strided input processing convolutions above with 4 layers of 3x3 convolutions (followed by one 1x1 layer) improves vision transformer training stability and convergence as well as ImageNet test accuracy.
Being that the input convolutional stem of the smaller ViT models is not capable of accurately representing an input, to understand later layers' representation capability we can substitute a trained input processing convolutional layer from a trained vision transformer and chain this layer to the rest of a model from an untrained ViT.
class NewVit(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor):
        # Apply a trained input convolution
        x = trained_vit._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch
        batch_class_token = self.model.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        # Apply the (untrained) first encoder block
        for i in range(1):
            x = self.model.encoder.layers[i](x)
        return x

vision_transformer = torchvision.models.vit_h_14().to(device)  # untrained model
trained_vit = torchvision.models.vit_h_14(weights='DEFAULT').to(device)  # DEFAULT is IMAGENET1K_SWAG_E2E_V1
trained_vit.eval()

new_vision = NewVit(vision_transformer).to(device)
new_vision.eval()
When various layer representations are generated for ViT Base 32 with a fixed (trained) input processing convolution, it is clear that there is a decrease in representation accuracy as depth increases in the encoder stack.
It is apparent that all encoder layers have imperfections in certain patches, making them less accurate than the input convolution layer's representation. It may be wondered why this is, being that the encoder layers have outputs of the same dimension as the input convolution (768 values per patch token), such that no further reduction in dimensionality occurs within the encoder stack.
Vision transformer models apply positional encodings to the tokens after the input convolution and before the first transformer encoder block, and notably this positional encoding is itself trained: initialized as a normal distribution via nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)), this positional encoding tensor (torch.nn.Parameter objects are torch.Tensor subclasses) is back-propagated through such that its elements are modified during training. It may be wondered whether a change in positional encoding parameters is responsible for the change in first encoder layer representational accuracy. This can be easily tested: re-assigning an untrained vision transformer's positional embedding parameters to those of a trained model may be accomplished as follows:
vision_transformer = torchvision.models.vit_b_32(weights='DEFAULT').to(device) # 'IMAGENET1K_V1'
vision_transformer.eval()
untrained_vision = torchvision.models.vit_b_32().to(device)
untrained_vision.encoder.pos_embedding = vision_transformer.encoder.pos_embedding
When this is done, however, there is no noticeable difference in the representation quality. Repeating the above procedure for other components of the first transformer encoder, we find that substituting the first layer norm (the one that is applied before the self-attention module) of the first encoder for a trained version is capable of increasing the representational quality substantially.
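This swap may be performed in the same manner as the positional embedding swap above, assuming the same trained (vision_transformer) and untrained (untrained_vision) models:

# replace only the first layer norm of the first encoder block with its trained counterpart
untrained_vision.encoder.layers[0].ln_1 = vision_transformer.encoder.layers[0].ln_1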
In the last section it was observed that changing the parameters (specifically swapping untrained parameters for trained ones) of a layer normalization operation led to an increase in representational accuracy. Later on this page we will see that layer normalization tends to decrease representational accuracy, and so we will stop to consider what exactly this transformation entails.
Given a layer's features $x_1, x_2, \dots, x_d$ indexed by $i$, layer normalization computes

$$y_i = \gamma \, \frac{x_i - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} + \beta$$

where $\mathrm{E}[x]$ and $\mathrm{Var}[x]$ are the mean and variance taken over the features being normalized, $\epsilon$ is a small constant added for numerical stability, and $\gamma$ and $\beta$ are learned scale and shift parameters.
Short consideration of the above formula should be enough to convince one that layer normalization is in general non-invertible, such that many different possible inputs may be mapped to the same output: any shift or rescaling applied uniformly to the features being normalized is removed by the subtraction of the mean and division by the standard deviation.
Thus it should come as no surprise that the vision transformer's representations of the input often form noticeably patchwork-like images when layer normalization is applied (to each patch separately). The values $\mathrm{E}[x]$ and $\mathrm{Var}[x]$ are computed independently for each patch token, so information about each patch's overall intensity and contrast is discarded separately, patch by patch.
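This loss of information can be demonstrated directly: two inputs that differ by a uniform shift and rescaling of their features are mapped to (nearly) identical outputs by layer normalization:

import torch
from torch import nn

layer_norm = nn.LayerNorm(768)
x = torch.randn(50, 768)   # 50 tokens, each normalized independently over its 768 features
y = 3 * x + 7              # a rescaled and shifted copy of x
print(torch.allclose(layer_norm(x), layer_norm(y), atol=1e-4))  # True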
We will now switch to larger vision transformers, mostly because these are the ones that performed well on ImageNet and other similar benchmarks. We can use a different image of a Tesla coil and apply this input to a ViT Large 16 model. This model accepts inputs of size 512x512 rather than the 224x224 used above and splits that input into patches of size 16x16, such that there are 32 x 32 = 1024 patch tokens (plus one class token).
Transformer encoders contain a number of operations: layer normalization, self-attention, feedforward fully connected neural networks, and residual addition connections. With the observation that removing layer normalization yields more accurate input representations from encoders before training in small vision transformers, it may be wondered what exactly in the transformer encoder module is necessary for representing an input, or equivalently what exactly in this module is capable of storing useful information about the input.
Recall the architecture of the vision transformer encoder module:
We are now going to focus on ViT Large 16, where the 'trained' modules are pretrained on weakly supervised datasets before being trained on ImageNet 1K, and images are all 3x512x512.
The first thing to note is that this model behaves similarly to ViT Base 32 with respect to the input convolution: applying a trained input convolution without switching to a trained first layernorm leads to patches of high-frequency signal in the input generation, which can be ameliorated by swapping to a trained layernorm.
The approach we will follow is an ablation survey: each component will be removed one after the other in order to observe which ones are required for input representation from the module output. Changes are made by subclassing the EncoderBlock module of ViT and then simply removing the relevant portions.
This class is originally as follows:
class EncoderBlock(nn.Module):
    """Transformer encoder block."""
    ...

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(x, x, x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y
to remove residual connections, we remove the tensor addition steps as follows:
class EncoderBlock(nn.Module):
    """Transformer encoder block."""
    ...

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
        x = self.dropout(x)
        # the residual addition x = x + input has been removed

        y = self.ln_2(x)
        y = self.mlp(y)
        return y  # the residual addition of x has been removed
Then we can replace the EncoderBlock
modules in the vision transformer with our new class containing no residual connections as follows:
# 24 encoder modules per vision transformer
for i in range(24):
    vision_transformer.encoder.layers[i] = EncoderBlock(16, 1024, 4096, 0., 0.)
Similar changes can be made to remove layer normalizations, MLPs, or self-attention modules. One problem remains: how to load trained model weights into our modified transformer encoders. As we have re-built the encoders to match the architecture of the original, however, and as the residual connections contain no trainable parameters, we can simply load the original trained model and replace each layer of our modified model with the trained version of that layer. For example, to modify ViT L 16 to discard residual connections while keeping trained MLP weights, we have
for i in range(24):
    vision_transformer.encoder.layers[i] = EncoderBlock(16, 1024, 4096, 0., 0.)

original_vision_transformer = torchvision.models.vit_l_16(weights='IMAGENET1K_SWAG_E2E_V1').to(device)
original_vision_transformer.eval()

for i in range(24):
    vision_transformer.encoder.layers[i].mlp = original_vision_transformer.encoder.layers[i].mlp
For a selection of encoder representations from an untrained ViT L 16 (with a trained input convolutional stem to allow for 512x512 inputs), we have the following input representations
For the last layer, the same number of iterations yields
It is clear that removal of either self-attention or (both) layer normalizations from each encoder module, but not removal of the MLPs, is sufficient to prevent the vast majority of the decrease in representation quality with increased depth up to encoder layer 12, whereas at the deepest layer (24) of this vision transformer we see that only removal of the attention modules from each encoder (not removal of MLPs or LayerNorms alone) prevents most of the decline in input representation accuracy.
This is somewhat surprising given that the MLP used in the transformer encoder architecture is itself typically non-invertible: standard practice, implemented in vision transformers, is for the MLP to be a one-hidden-layer fully connected network (with input and output dimensions equal to the dimension of the self-attention hidden layer and a hidden dimension four times larger) using GELU nonlinearities, and GELU is not a one-to-one function.
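For reference, a module equivalent in structure to the encoder MLP of ViT L 16 (hidden dimension 1024, expanded fourfold) may be sketched as follows:

from torch import nn

# equivalent in structure to the MLP block in each ViT L 16 encoder
encoder_mlp = nn.Sequential(
    nn.Linear(1024, 4096),  # expand from the hidden dimension to four times its size
    nn.GELU(),              # GELU is not one-to-one, so the block is not invertible in general
    nn.Dropout(0.),
    nn.Linear(4096, 1024),  # project back down to the hidden dimension
    nn.Dropout(0.),
)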
Likewise, self-attention and layernorm transformations are both non-invertible, and yet removal of only one or the other appears sufficient for substantially improving input representation accuracy in early layers. From the rate of minimization of the embedding distance

$$\lVert O_l(a_g, \theta) - O_l(a, \theta) \rVert$$

it is clear that self-attention layers are extremely poorly conditioned: including these makes satisfying $O_l(a_g, \theta) \approx O_l(a, \theta)$ (where $O_l(a, \theta)$ denotes the output of layer $l$ of model $\theta$ for the target image $a$, and $a_g$ denotes the generated input) a very slow process requiring a great many gradient descent iterations.
Removal of layer normalization transformations does not yield as much of an increase in input representation accuracy for the last layer (24) of the ViT Large 16. This is because without normalization, at that depth the gradient begins to explode for certain patches: observe the high-frequency signal originating from two patches near the center of the layernorm-less example above. A very small gradient update rate $\epsilon$ is required to prevent this explosion.
Now we will examine the effects of residual connections on input representation in the context of vision transformers. After removing all residual connections from each transformer encoder layer, we have
It is apparent from these results that self-attention transformations are largely incapable of transmitting input information, which is why removing all attention modules from encoders 1 through 4 results in a recovery of the ability of the output of encoder 4 to accurately represent the input.
It may be observed that even a single self-attention layer requires an enormous number of gradient descent iterations to achieve a marginally accurate representation for a trained model, and that even this is insufficient for an untrained one. This is evidence for approximate non-invertibility, which may equivalently be viewed as poor conditioning in the forward transformation.
It is interesting to observe that training leads to a somewhat more informative (with respect to the input) self-attention module, considering that attention layers from trained models still carry very little input information relative to fully connected layers. There may be some benefit for the transformer encoder in transmitting more information during the learning procedure, but it is unclear why this would be the case.
What about true non-invertibility? If our gradient descent procedure on the input is effective, then an unbounded increase in the number of iterations of gradient descent would be expected to result in an asymptotically zero distance between the target layer output $O_l(a, \theta)$ and the generated input's layer output $O_l(a_g, \theta)$.
If this is the case, we can easily find evidence that points to non-invertibility as a cause of the poor input representation for attention layers. For an invertible transformation, as the embedding distance $\lVert O_l(a_g, \theta) - O_l(a, \theta) \rVert \to 0$, the input distance must follow, $\lVert a_g - a \rVert \to 0$, since only one input can yield the target output.
On the other hand, if we find that the embedding distance heads towards the origin,

$$\lVert O_l(a_g, \theta) - O_l(a, \theta) \rVert \to 0$$

but the input distance does not,

$$\lVert a_g - a \rVert \not\to 0$$

then our representation procedure cannot distinguish between the multiple inputs that may yield one output. And indeed, this is found to be the case for the representation of the Tesla coil above for one encoder module consisting of layer normalization followed by multi-head attention alone.
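This check may be sketched by tracking both distances over the course of the representation procedure; layer_module below stands for whatever portion of the encoder is under study (here a layer norm followed by multi-head attention), a is the target input, and the iteration count and update rate are placeholders:

import torch

def track_distances(layer_module, a, iterations=10000, epsilon=0.001):
    with torch.no_grad():
        target_output = layer_module(a)
    a_g = torch.rand_like(a)  # the generated input to be optimized
    for i in range(iterations):
        a_g = a_g.detach().requires_grad_(True)
        embedding_distance = torch.sum(torch.abs(layer_module(a_g) - target_output))
        embedding_distance.backward()
        with torch.no_grad():
            a_g = a_g - epsilon * a_g.grad
            input_distance = torch.sum(torch.abs(a_g - a))
        if i % 1000 == 0:
            # for an invertible layer both distances shrink together; a shrinking embedding
            # distance with a stagnant input distance indicates many inputs per output
            print(i, embedding_distance.item(), input_distance.item())
    return a_g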
Thus it is apparent that self-attention layers are generally incapable of accurately representing an input due to non-invertibility as well as poor conditioning if residual connections are removed.
It is not particularly surprising that self-attention should be non-invertible, first and foremost because the transformations present in the self-attention layer (more specifically the multi-head attention layer) are together non-invertible in the general case. Recall that the first step of attention (after forming $Q$, $K$, and $V$) is the matrix product $QK^T$ followed by a softmax: the softmax is unchanged if a constant is added to every element of a row of its input, and the multiplication of the resulting attention weights by $V$ can map many distinct inputs to the same output, so the composition is not one-to-one in general.
Self-attention may be made invertible if residual connections are included and constraints are placed on the self-attention transformation's Lipschitz constants, as observed by Zha and colleagues. That said, it is apparent from the experiments above that the vision transformer's attention modules are indeed invertible to at least some degree without modification if residuals are allowed (especially if layernorms are removed).
It may be wondered how much attention layers contribute to a trained model's ability to transform the input to match the manifolds learned during training. Removing self-attention layers or MLPs from each transformer encoder, we see some slight highlighting of certain areas of the Tesla coil in the deeper representations without attention layers, but not in those without MLPs.
The same observations are made for an input image of a Dalmatian, which is one of the 1000 classes that the vision transformer is trained upon.
It may be appreciated that training results in a substantial increase in the ability of attention modules (in the context of ViT Large 16 without MLP layers) to represent an input, but at the same time it is apparent that each attention layer severely limits the information that passes through to the output.
As vision transformers are effective regardless of this severely limited passage of information, it may be wondered whether this matters for the general goal of image classification. For the training process there certainly is a significant effect of such information restriction: if one removes the residual connections from ViTs, the resulting model is very difficult to train and fails to learn even modest datasets such as CIFAR-10.
On the other hand, if we remove residual connections from MLP-mixers (which are more or less identical to transformers except that the self-attention layer has been swapped for a transposed feedforward one), the resulting model is not difficult to optimize, and indeed suffers only a limited decrease in accuracy.
The process of layer representation visualization relies on gradient descent to modify an initially random input such that the layer in question cannot 'distinguish' between the modified input $a_g$ and the original image $a$, in the sense that $O_l(a_g, \theta) \approx O_l(a, \theta)$.
This is not necessarily a problem for classification: far from it, we typically want many inputs to be mapped to the same output, as every image of a given class should receive that class's label.
On the other hand, poor convergence of the gradient descent procedure is likely to indicate difficulty in training. Suppose that it takes a very large number of iterations of the update

$$a_{n+1} = a_n - \epsilon \, \nabla_{a_n} \lVert O_l(a_n, \theta) - O_l(a, \theta) \rVert$$

(where $\epsilon$ is the update rate and $a$ the target image) before $O_l(a_n, \theta) \approx O_l(a, \theta)$ is satisfied.
Now consider the learning procedure of stochastic gradient descent, in which the model's parameters $\theta$ are updated using the gradient of an objective function $J$ with respect to those parameters,

$$\theta_{n+1} = \theta_n - \epsilon \, \nabla_{\theta_n} J(O(a, \theta_n))$$

It can be recognized that these are closely related optimization problems. In particular, both depend on useful gradient information passing through the very same layers: if minimizing a loss by gradient descent on the input is poorly conditioned, minimizing a loss by gradient descent on the parameters of those layers is likely to be poorly conditioned as well.
After the successes of vision transformers, Tolstikhin and colleagues and, independently, Melas-Kyriazi investigated whether or not self-attention is necessary for the efficacy of vision transformers. Somewhat surprisingly, the answer from both groups is no: replacing the attention layer with a fully connected layer leads to a minimal decline in model performance while requiring significantly less compute than the traditional transformer model. When compute is held constant, Tolstikhin and colleagues find that there is little difference or even a slight advantage to the attentionless models, and Melas-Kyriazi finds that, conversely, using only attention results in very poor performance.
The models investigated have the same encoder stacks present in vision transformers, but each encoder block contains two fully connected layers. The first fully connected layer is applied to the features of each patch: for example, if the hidden dimension of each patch were 512, then that is the input dimension of each of these parallel layers. The second layer is applied over the patch tokens (such that its input dimension is the number of tokens the model has). These models were referred to as 'MLP-Mixers' by the Tolstikhin group, which included this helpful graphical summary of the architecture:
There is a notable difference between the mixer architecture and the Vision Transformer: each encoder block in the ViT places the attention layer first and follows this with the MLP layer, whereas each block in the attentionless mixer architecture has the feature MLP first and the mixer second.
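A minimal sketch of one such mixer block, following the ordering described above (the feature MLP over each patch's hidden dimension first, then a 'mixing' MLP applied across the patch tokens), is given below; the dimensions, expansion factor, and use of layer normalization and residual connections are assumptions for the sake of illustration:

import torch
from torch import nn

class MixerBlock(nn.Module):
    # hidden_dim: features per patch token; num_tokens: number of patches
    def __init__(self, hidden_dim=512, num_tokens=196, expansion=4):
        super().__init__()
        self.ln_1 = nn.LayerNorm(hidden_dim)
        self.feature_mlp = nn.Sequential(  # applied to each patch's features
            nn.Linear(hidden_dim, expansion * hidden_dim),
            nn.GELU(),
            nn.Linear(expansion * hidden_dim, hidden_dim),
        )
        self.ln_2 = nn.LayerNorm(hidden_dim)
        self.token_mlp = nn.Sequential(    # applied across the patch tokens
            nn.Linear(num_tokens, expansion * num_tokens),
            nn.GELU(),
            nn.Linear(expansion * num_tokens, num_tokens),
        )

    def forward(self, x):
        # x: (batch, num_tokens, hidden_dim)
        x = x + self.feature_mlp(self.ln_1(x))     # feature (per-patch) mixing
        y = self.ln_2(x).transpose(1, 2)           # swap token and feature dimensions
        x = x + self.token_mlp(y).transpose(1, 2)  # token mixing
        return x

block = MixerBlock()
print(block(torch.randn(1, 196, 512)).shape)  # torch.Size([1, 196, 512])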
We investigate the ability of various layers of an untrained MLP-mixer to represent an input image. We employ an architecture with a patch size of 16x16 applied to a 224x224x3 input image (such that there are 14 x 14 = 196 patches, each of dimension 3x16x16 = 768). As each encoder layer is overcomplete with respect to the input, one might expect each layer to be capable of representing the input accurately.
Somewhat surprisingly, this is not found to be the case: the first encoder layer of the above model is not particularly accurate at autoencoding its input, and the autoencoding's accuracy declines the deeper the layer in question. Specifying a smaller patch size of 8 does reduce the representation error.
It may be wondered why the representation of each encoder layer for the 16-sized patch model is poor, being that each transformer encoder in the model is overcomplete with respect to the input.
This poor representation must therefore be (mostly) due to approximate non-invertibility (that is, poor conditioning), and this is borne out in practice as the distance between the layer's output for the generated input and its output for the target input,

$$\lVert O_l(a_g, \theta) - O_l(a, \theta) \rVert$$

is empirically difficult to reduce beyond a certain amount. By tinkering with the MLP encoder modules, we find that this is mostly due to the presence of layer normalization: removing this transformation (from every MLP block) removes much of the empirical difficulty of minimizing this distance. Representations for the Melas-Kyriazi implementation are shown in the following figure.
After training on ImageNet, we find that the input representations are broadly similar to those found for ViT base models, with perhaps some modest increase in clarity (i.e. accuracy compared to the original) in the deeper layers.
For further investigation into the features of MLP-mixers and vision transformers, see this page.