Attention Layer Bottleneck #24
So, this is the actual implementation being used for axial attention: https://github.com/lucidrains/axial-attention/blob/eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L153-L178 which seems to be doing both vertical and horizontal passes. But I just realized that we don't actually add any position embeddings, other than the lat/lon inputs, before passing to the axial attention. So we might need to add them and see what happens. The number of heads is the number of heads for multi-headed attention, so we can probably just set it to one and be fine, I think.
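A minimal NumPy sketch of what axial attention computes, assuming plain linear Q/K/V projections, a residual connection, and a height pass followed by a width pass. The function names and parameters are hypothetical and do not come from the linked library, which may combine the two passes differently. The `heads` argument splits the channels into independent heads (`heads=1` is single-head attention). The last lines illustrate the point about position embeddings: shuffling the columns of the input only shuffles the columns of the output, so any positional information has to come from the feature values themselves (such as the lat/lon inputs).

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attend_along_axis(x, wq, wk, wv, axis, heads=1):
    """Multi-head self-attention along one spatial axis of a (C, H, W) map.

    axis=1 attends along height (each column is a sequence),
    axis=2 attends along width (each row is a sequence).
    C must be divisible by `heads`.
    """
    c = x.shape[0]
    d = c // heads
    seq = np.moveaxis(x, 0, -1)            # (H, W, C): rows are sequences
    if axis == 1:
        seq = np.swapaxes(seq, 0, 1)       # (W, H, C): columns are sequences
    n, length, _ = seq.shape

    def split(t):
        # (N, L, C) -> (N, heads, L, d)
        return t.reshape(n, length, heads, d).transpose(0, 2, 1, 3)

    q, k, v = split(seq @ wq), split(seq @ wk), split(seq @ wv)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d)      # (N, heads, L, L)
    out = softmax(scores) @ v                               # (N, heads, L, d)
    out = out.transpose(0, 2, 1, 3).reshape(n, length, c)   # concatenate heads
    if axis == 1:
        out = np.swapaxes(out, 0, 1)
    return np.moveaxis(out, -1, 0)         # back to (C, H, W)

def axial_attention(x, params, heads=1):
    # Height pass, then width pass, each with a residual connection.
    x = x + attend_along_axis(x, *params["height"], axis=1, heads=heads)
    return x + attend_along_axis(x, *params["width"], axis=2, heads=heads)

rng = np.random.default_rng(0)
C, H, W = 8, 28, 28
x = rng.standard_normal((C, H, W))
params = {name: [rng.standard_normal((C, C)) / np.sqrt(C) for _ in range(3)]
          for name in ("height", "width")}
y = axial_attention(x, params, heads=2)

# No position information: shuffling the columns only shuffles the output.
perm = rng.permutation(W)
y_shuffled = axial_attention(x[:, :, perm], params, heads=2)
print(np.allclose(y_shuffled, y[:, :, perm]))  # True
```

The same holds for shuffling rows, which is why a separate position embedding is usually added before attention layers like this.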
Okay, how do we do this? If we have 8 channels in the RNN output with 28×28 height and width, is this embedding information about which pixel we are in? I am struggling a bit to wrap my head around attention and axial attention... Also, when you say set the number of heads to 1, you mean for debugging, right? We still want multi-head attention to replicate their model in the end.
Yeah, set it to 1 for debugging, to get it to overfit first. And yes, the position embedding says which pixel we are in and where that pixel is located relative to the other pixels in the input. The library has a function for it, so we can probably just use it wherever we use the axial attention: #25
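To make "which pixel we are in" concrete, here is a minimal NumPy sketch of a fixed sinusoidal position embedding for an 8-channel, 28×28 map, assuming half of the channels encode the row and half the column. The function names are hypothetical, and the library helper referenced in #25 may well use learned rather than fixed embeddings; this version just needs no training.

```python
import numpy as np

def sinusoid_1d(length, dim):
    """Fixed sinusoidal encoding of positions 0..length-1, shape (dim, length); dim must be even."""
    pos = np.arange(length)[None, :]                       # (1, length)
    freqs = 1.0 / 10000 ** (np.arange(0, dim, 2) / dim)    # (dim // 2,)
    angles = freqs[:, None] * pos                          # (dim // 2, length)
    emb = np.empty((dim, length))
    emb[0::2] = np.sin(angles)
    emb[1::2] = np.cos(angles)
    return emb

def axial_position_embedding(channels, height, width):
    """(channels, height, width) map: the first half of the channels encodes the row,
    the second half the column. `channels` must be divisible by 4."""
    half = channels // 2
    rows = sinusoid_1d(height, half)[:, :, None]   # (half, H, 1)
    cols = sinusoid_1d(width, half)[:, None, :]    # (half, 1, W)
    return np.concatenate([np.broadcast_to(rows, (half, height, width)),
                           np.broadcast_to(cols, (half, height, width))])

C, H, W = 8, 28, 28
rnn_out = np.zeros((C, H, W))          # stand-in for the 8-channel RNN output
pos = axial_position_embedding(C, H, W)
x = rnn_out + pos                       # what would be passed to the axial attention

# Every one of the 28*28 pixels gets a distinct 8-d code.
codes = pos.reshape(C, -1).T            # (784, 8)
dists = np.linalg.norm(codes[:, None, :] - codes[None, :, :], axis=-1)
np.fill_diagonal(dists, np.inf)
print(dists.min() > 0)                  # True
```

Adding the embedding keeps the channel count at 8; concatenating it as extra input channels, like the lat/lon inputs, is the other common choice.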
I don't know if it's relevant, but it recently occurred to me that MetNet version 1 is quite similar to the Temporal Fusion Transformer (TFT) (also from Google!), except that MetNet has 2 spatial dimensions, whilst TFT is for time series without any (explicit) spatial dimensions. In particular, both TFT and MetNet use an RNN followed by multi-head attention. In the TFT paper, they claim that the RNN generates a kind of learnt position encoding, so they don't bother with a "hand-crafted" position encoding. The TFT paper says:
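A toy NumPy illustration of the TFT argument (not TFT's or MetNet's actual architecture; all weights and sizes are arbitrary assumptions): a vanilla RNN fed the same input at every step still produces a different hidden state at each step, so attention applied on top of those states could in principle tell the steps apart. Note that in this toy the recurrence runs along the same axis the positions are measured on; in MetNet the RNN runs over time while the axial attention runs over height and width, so the toy does not by itself show that the RNN supplies spatial positions.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, steps = 8, 6
W = rng.standard_normal((hidden, hidden)) * 0.5   # recurrent weights
U = rng.standard_normal((hidden, 1))              # input weights
b = rng.standard_normal(hidden) * 0.1             # bias

x = np.ones((steps, 1))   # the SAME input at every time step
h = np.zeros(hidden)
states = []
for t in range(steps):
    h = np.tanh(W @ h + U @ x[t] + b)   # vanilla (Elman) RNN update
    states.append(h)
states = np.stack(states)                # (steps, hidden)

# The inputs carry no order information, yet the hidden states differ per step.
dists = np.linalg.norm(states[:, None] - states[None, :], axis=-1)
np.fill_diagonal(dists, np.inf)
print(dists.min())
```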
I can confirm that initial tests now show promising results; the network seems to learn something :) I'll be back with more results in a few days.
@all-contributors please add @jacobbieker for code
I've put up a pull request to add @jacobbieker! 🎉
@all-contributors please add @JackKelly for code
I've put up a pull request to add @JackKelly! 🎉
@all-contributors please add @ValterFallenius for userTesting
I've put up a pull request to add @ValterFallenius! 🎉
Found a bottleneck: the attention layer
I have found a potential explanation for why bug #22 occurred: the axial attention layer seems to act as a bottleneck. I ran the network for 1000 epochs to try to overfit a small subset of 4 samples (see the run on WandB). The network barely reduces the loss, does not overfit the data, and yields a very poor result, some kind of mean. See image below:
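The "some kind of mean" output is the classic symptom of a network whose output has stopped depending on its input. A minimal NumPy sketch, assuming an MSE-style loss and random stand-in targets: the best input-independent prediction is exactly the mean target, and gradient descent converges to it. (For a categorical cross-entropy loss, the best input-independent prediction is instead the average class frequency, which is also a kind of mean.)

```python
import numpy as np

rng = np.random.default_rng(0)
targets = rng.random((4, 28, 28))   # stand-in for the 4 training targets

# A network whose output no longer depends on its input can only emit the
# same map c for every sample. Fit c by gradient descent on the MSE.
c = np.zeros((28, 28))
lr = 0.1
for _ in range(200):
    grad = 2 * (c - targets).mean(axis=0)   # d/dc of the per-pixel mean squared error
    c -= lr * grad

print(np.allclose(c, targets.mean(axis=0)))  # True: the best constant output is the mean target
```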
After removing the axial attention layer, the model behaves as expected and overfits the training data; see below after 100 epochs:
The message from the author listed in #19 does mention that our implementation of axial attention seems to be very different from theirs. He says: "Our (Google's) heads were small MLPs as far as I remember (I'm not at google anymore so do not have access to the source code)." I am not experienced enough to dig into the source code of our axial attention library to see how it differs from theirs.
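The quoted remark does not say what "small MLPs" means, so the following is only a guess: one possible reading is that each head computes its queries, keys and values with a small two-layer MLP instead of the single linear projection used in standard attention. A minimal NumPy sketch of that hypothetical variant on one 28-pixel row (all names and sizes are assumptions, not Google's code):

```python
import numpy as np

def mlp(x, w1, b1, w2, b2):
    # Two-layer perceptron with ReLU, applied along the last axis.
    return np.maximum(x @ w1 + b1, 0.0) @ w2 + b2

def softmax(x):
    # Numerically stable softmax over the last axis.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def mlp_head_attention(seq, q_mlp, k_mlp, v_mlp):
    """Single-head self-attention over seq (L, C), with Q, K, V produced by
    small MLPs instead of single linear projections (hypothetical variant)."""
    q, k, v = mlp(seq, *q_mlp), mlp(seq, *k_mlp), mlp(seq, *v_mlp)
    scores = q @ k.T / np.sqrt(q.shape[-1])   # (L, L)
    return softmax(scores) @ v                # (L, C)

rng = np.random.default_rng(0)
C, hidden, L = 8, 16, 28

def init():
    # Random weights for one small MLP: C -> hidden -> C.
    return (rng.standard_normal((C, hidden)) / np.sqrt(C), np.zeros(hidden),
            rng.standard_normal((hidden, C)) / np.sqrt(hidden), np.zeros(C))

row = rng.standard_normal((L, C))    # one row of a (C, H, W) map, channels last
out = mlp_head_attention(row, init(), init(), init())
print(out.shape)  # (28, 8)
```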