
Attention Layer Bottleneck #24

Open
ValterFallenius opened this issue Mar 15, 2022 · 11 comments
Labels
bug Something isn't working

Comments

ValterFallenius commented Mar 15, 2022

Found a bottleneck: the attention layer
I have found a potential cause of bug #22: the axial attention layer appears to be acting as a bottleneck. I ran the network for 1000 epochs to try to overfit a small subset of 4 samples (see the run on WandB). The network is barely able to reduce the loss and does not overfit the data; it yields a very poor result, something like a mean prediction. See image below:

[Image: yhat against y]

After removing the axial attention layer, the model behaves as expected and overfits the training data; see below after 100 epochs:

[Image: yhat against y without attention]
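For reference, the kind of overfit sanity check described above might look roughly like the sketch below. It assumes a generic PyTorch model and a small fixed batch; the function name, loss, and hyperparameters are hypothetical, not the actual training code.

```python
import torch
import torch.nn.functional as F

def overfit_check(model, x, y, epochs=1000, lr=1e-3):
    """Try to overfit a tiny fixed batch; a healthy model should drive the loss towards zero."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        optimizer.zero_grad()
        loss = F.mse_loss(model(x), y)  # MSE is just a placeholder loss here
        loss.backward()
        optimizer.step()
        if epoch % 100 == 0:
            print(f"epoch {epoch}: loss {loss.item():.4f}")
```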

The message from the author quoted in #19 does mention that our implementation of axial attention seems to be very different from theirs. He says: "Our (Google's) heads were small MLPs as far as I remember (I'm not at google anymore so do not have access to the source code)." I am not experienced enough to look into the source code of our axial attention library to see how it differs from theirs.

  1. What are heads in axial attention, and what difference does the number of heads make?
  2. Are we doing both vertical and horizontal attention passes in our implementation?
ValterFallenius added the bug (Something isn't working) label on Mar 15, 2022
jacobbieker (Member) commented:

So, this is the actual implementation being used for axial attention: https://github.com/lucidrains/axial-attention/blob/eff2c10c2e76c735a70a6b995b571213adffbbb7/axial_attention/axial_attention.py#L153-L178 which does seem to do both vertical and horizontal passes. But I just realized that we don't actually do any position embeddings, other than the lat/lon inputs, before passing to the axial attention, so we might need to add that and see what happens. The number of heads is the number of heads for multi-headed attention, so we can probably just set it to one and be fine, I think.
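A minimal sketch of how that class is wired up, following the library's README; the 4-sample, 8-channel, 28×28 tensor below is just the example shape from this thread, not the actual model configuration:

```python
import torch
from axial_attention import AxialAttention

x = torch.randn(4, 8, 28, 28)  # hypothetical RNN output: (batch, channels, height, width)

attn = AxialAttention(
    dim = 8,              # embedding (channel) dimension
    dim_index = 1,        # which tensor axis holds the channels
    heads = 1,            # number of heads for multi-head attention; 1 while debugging
    num_dimensions = 2,   # two spatial axes -> one attention pass along height, one along width
    sum_axial_out = True  # sum the height-wise and width-wise attention outputs
)

out = attn(x)  # shape is preserved: (4, 8, 28, 28)
```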

ValterFallenius (Author) commented:

Okay, how do we do this?

If we have 8 channels in the RNN output with 28×28 height and width, is this embedding information about which pixel we are in? I am struggling a bit to wrap my head around attention and axial attention...

Also, when you say to set the number of heads to 1, you mean for debugging, right? We still want multi-head attention to replicate their model in the end.

jacobbieker (Member) commented:

Yeah, set it to 1 for debugging to get it to overfit first. And yeah, the position embedding encodes which pixel we are in and where that pixel sits relative to the other pixels in the input. The library has a function for it, so we can probably just add it where we use the axial attention: #25
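A minimal sketch of what that could look like with the library's AxialPositionalEmbedding applied just before the attention call; the shapes and hyperparameters are the same hypothetical example as above, not the actual change in #25:

```python
import torch
from axial_attention import AxialAttention, AxialPositionalEmbedding

x = torch.randn(4, 8, 28, 28)  # hypothetical RNN output: (batch, channels, height, width)

# Learned positional embedding per spatial location, so attention knows
# which pixel each feature vector came from.
pos_emb = AxialPositionalEmbedding(dim = 8, shape = (28, 28))

attn = AxialAttention(dim = 8, dim_index = 1, heads = 1, num_dimensions = 2)

x = pos_emb(x)  # adds position information; shape unchanged
out = attn(x)   # (4, 8, 28, 28)
```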

JackKelly (Member) commented Mar 15, 2022

> I just realized that we don't actually do any position embeddings, other than the lat/lon inputs, before passing to the axial attention. So we might need to add that and see what happens?

I don't know if it's relevant, but it recently occurred to me that MetNet version 1 is quite similar to the Temporal Fusion Transformer (TFT) (also from Google!), except MetNet has 2 spatial dimensions, whilst TFT is for time series without any (explicit) spatial dimensions. In particular, both TFT and MetNet use an RNN followed by multi-head attention. In the TFT paper, they claim that the RNN generates a kind of learnt position encoding, so they don't bother with a "hand-crafted" position encoding.

The TFT paper says:

> [The LSTM] also serves as a replacement for standard positional encoding, providing an appropriate inductive bias for the time ordering of the inputs.

ValterFallenius (Author) commented:

I can confirm that initial tests show promising results; the network seems to learn something now :) I'll be back with more results in a few days.

peterdudfield (Contributor) commented:

@all-contributors please add @jacobbieker for code

allcontributors (Contributor) commented:

@peterdudfield

I've put up a pull request to add @jacobbieker! 🎉

peterdudfield (Contributor) commented:

@all-contributors please add @JackKelly for code

allcontributors (Contributor) commented:

@peterdudfield

I've put up a pull request to add @JackKelly! 🎉

peterdudfield (Contributor) commented:

@all-contributors please add @ValterFallenius for userTesting

allcontributors (Contributor) commented:

@peterdudfield

I've put up a pull request to add @ValterFallenius! 🎉
