I'm struggling to import the weights from torchvision's ViT into ours. The problem is that the correct mapping from the attention layers in torch to the ones in Metalhead seems non-trivial.
```julia
using PythonCall, Metalhead

torch = pyimport("torch")

# convert a PyTorch tensor to a Julia array, reversing the dimension order
function th2jl(x::Py)
    xj = pyconvert(Array, x.detach().numpy())
    xj = permutedims(xj, ndims(xj):-1:1)
    return xj
end

m = torch.nn.MultiheadAttention(embed_dim=2, num_heads=1, batch_first=true, bias=false, add_bias_kv=false)

# python forward pass
x = torch.randn(1, 3, 2)
y, a = m(x, x, x, need_weights=true)

mj = Metalhead.MHAttention(2, 1, qkv_bias=false)

# copy weights; transpose back since Linear layers in pytorch store the weight transposed
mj.qkv_layer.weight .= th2jl(m.in_proj_weight)'
mj.projection.layers[1].weight .= th2jl(m.out_proj.weight)'

# julia forward pass
xj = th2jl(x)
yj = mj(xj)
@assert yj ≈ th2jl(y) # false
```
Probably this is due to the permutations and chunking in our initial projection; possibly we should rearrange them in such a way that the natural weight mapping from PyTorch just works.
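For reference, PyTorch stores the fused projection as `in_proj_weight` with shape `(3 * embed_dim, embed_dim)`: `W_q`, `W_k`, `W_v` stacked row-wise, not interleaved per head. A minimal single-head sketch on the pure PyTorch side (variable names like `y_manual` are mine, for illustration) that reproduces `nn.MultiheadAttention`'s forward pass from those chunks, which is the layout any Julia-side copy has to match:

```python
import torch

torch.manual_seed(0)
embed_dim, num_heads = 2, 1
head_dim = embed_dim // num_heads

m = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, bias=False)
x = torch.randn(1, 3, embed_dim)
y, _ = m(x, x, x, need_weights=False)

# in_proj_weight is (3*embed_dim, embed_dim): W_q, W_k, W_v stacked row-wise
wq, wk, wv = m.in_proj_weight.chunk(3, dim=0)
q, k, v = x @ wq.T, x @ wk.T, x @ wv.T  # nn.Linear computes x @ W.T

attn = torch.softmax((q @ k.transpose(-2, -1)) / head_dim**0.5, dim=-1)
y_manual = (attn @ v) @ m.out_proj.weight.T

assert torch.allclose(y, y_manual, atol=1e-5)
```

If the manual version matches, then any remaining mismatch on the Julia side must come from how `MHAttention` chunks and permutes its fused `qkv_layer`, not from the transpose convention of `Linear`.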
The attention layers need a lot of TLC. They were written before a lot of functionality landed in upstream libraries such as FluxML/NNlib.jl#455, and so are presumably not only slower but also doing way more work than they need to. This is one aspect that someone can take up and rewrite. Given that NNlib now has the functionality we need, only a couple of questions remain:

1. Does NNlib's attention also have an NNlibCUDA equivalent giving us good performance on GPUs?
2. Metalhead should probably not use the attention function directly, but instead use something like TensorCast along with the extended `batched_mul` to write the `MHAttention` layer. Is this GPU-friendly, AD-friendly, and performant enough? If not, we can always fall back to the NNlib version, but that prevents Metalhead from adding its own goodies.
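Whichever route a rewrite takes, it has to match PyTorch's head-splitting convention: the embedding axis is cut into `num_heads` contiguous slices of size `head_dim`. A hedged multi-head sketch (plain PyTorch; the `split_heads` helper is mine) that reproduces the built-in layer and pins down that convention:

```python
import torch

torch.manual_seed(0)
embed_dim, num_heads = 4, 2
head_dim = embed_dim // num_heads

m = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, bias=False)
x = torch.randn(1, 5, embed_dim)
y, _ = m(x, x, x, need_weights=False)

wq, wk, wv = m.in_proj_weight.chunk(3, dim=0)

def split_heads(t):
    # (batch, len, embed) -> (batch, heads, len, head_dim);
    # each head gets a contiguous head_dim-sized slice of the last axis
    b, l, _ = t.shape
    return t.view(b, l, num_heads, head_dim).transpose(1, 2)

q, k, v = (split_heads(x @ w.T) for w in (wq, wk, wv))
attn = torch.softmax((q @ k.transpose(-2, -1)) / head_dim**0.5, dim=-1)
out = (attn @ v).transpose(1, 2).reshape(x.shape)  # merge heads back
y_manual = out @ m.out_proj.weight.T

assert torch.allclose(y, y_manual, atol=1e-5)
```

A Julia rewrite built on NNlib's batched attention (or TensorCast plus `batched_mul`) would need the same contiguous head slicing for the PyTorch weight mapping to work without extra permutations.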
Pinging @theabhirath for more insights.