import torch
from axial_attention import AxialAttention
img = torch.randn(1, 3, 256, 256)
attn = AxialAttention(
dim = 3, # embedding dimension
dim_index = 1, # index of the embedding (channel) dimension in the input tensor
dim_heads = 32, # dimension of each head. defaults to dim // heads if not supplied
heads = 1, # number of heads for multi-head attention
num_dimensions = 2, # number of axial dimensions (images is 2, video is 3, or more)
sum_axial_out = True # whether to sum the attention outputs of each axis, or to run the input through the axes sequentially; defaults to True
)
attn(img) # (1, 3, 256, 256)
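To make the `sum_axial_out` comment concrete, here is a minimal sketch of the two combination modes. It assumes that "sum" means every per-axis attention sees the original input and the results are added, while "sequential" means the output of one axis feeds the next. Plain functions stand in for the real attention layers, and `combine_axes`, `height_fn` and `width_fn` are hypothetical names, not part of `axial_attention`.

```python
from typing import Callable, List


def combine_axes(x: float, axis_fns: List[Callable[[float], float]], sum_axial_out: bool) -> float:
    """Hypothetical sketch of how per-axis modules could be combined."""
    if sum_axial_out:
        # Each axis receives the original input; their outputs are summed.
        return sum(fn(x) for fn in axis_fns)
    # Otherwise the input passes through each axis one after another.
    out = x
    for fn in axis_fns:
        out = fn(out)
    return out


def height_fn(v: float) -> float:
    # Stand-in for attention along the height axis (assumption).
    return v + 1


def width_fn(v: float) -> float:
    # Stand-in for attention along the width axis (assumption).
    return v * 2


print(combine_axes(3.0, [height_fn, width_fn], sum_axial_out=True))   # (3 + 1) + (3 * 2) = 10.0
print(combine_axes(3.0, [height_fn, width_fn], sum_axial_out=False))  # (3 + 1) * 2 = 8.0
```

The two modes generally give different results, because in sequential mode the second axis attends over features already mixed along the first axis.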
Thanks for this great project. I have a question: if my image has only one channel, does that change the value of num_dimensions?
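One way to reason about the question is to list which tensor dimensions are axial. The sketch below assumes index 0 is the batch dimension, `dim_index` marks the single embedding (channel) dimension, and every remaining dimension is axial. `axial_dims` is a hypothetical helper written for illustration, not a function of `axial_attention`. Under that assumption, going from 3 channels to 1 changes `dim` (3 to 1) but leaves the number of axial dimensions at 2 for images.

```python
from typing import List, Sequence


def axial_dims(shape: Sequence[int], dim_index: int) -> List[int]:
    """Hypothetical helper: indices treated as axial, assuming index 0 is
    the batch and dim_index is the single embedding (channel) dimension."""
    ndim = len(shape)
    if dim_index < 0:
        dim_index += ndim  # allow negative indices such as -1
    return [d for d in range(1, ndim) if d != dim_index]


shapes = {
    "rgb": (1, 3, 256, 256),       # 3-channel image
    "gray": (1, 1, 256, 256),      # 1-channel image
    "video": (1, 1, 8, 256, 256),  # 1-channel video with 8 frames
}

for name, shape in shapes.items():
    dims = axial_dims(shape, dim_index=1)
    print(name, dims, "num_dimensions =", len(dims))
```

Here the channel count only affects the size of the embedding dimension, while the count of axial dimensions depends on how many spatial (or temporal) axes the tensor has.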