Shift features implementation #1
Comments
I think 4 different rolls will need a
Thank you for the suggestion. We simply adopted the shift implementation from TSM. I agree that the

```python
import torch

def shift_feat_memcopy(x):
    B, C, H, W = x.shape
    g = C // 12
    out = torch.zeros_like(x)
    out[:, g * 0:g * 1, :, :-1] = x[:, g * 0:g * 1, :, 1:]  # shift left
    out[:, g * 1:g * 2, :, 1:] = x[:, g * 1:g * 2, :, :-1]  # shift right
    out[:, g * 2:g * 3, :-1, :] = x[:, g * 2:g * 3, 1:, :]  # shift up
    out[:, g * 3:g * 4, 1:, :] = x[:, g * 3:g * 4, :-1, :]  # shift down
    out[:, g * 4:, :, :] = x[:, g * 4:, :, :]  # no shift
    return out

def shift_feat_roll(x):
    B, C, H, W = x.shape
    g = C // 12
    x0 = torch.roll(x[:, g * 0:g * 1], 1, 3)
    x1 = torch.roll(x[:, g * 1:g * 2], -1, 3)
    x2 = torch.roll(x[:, g * 2:g * 3], 1, 2)
    x3 = torch.roll(x[:, g * 3:g * 4], -1, 2)
    out = torch.cat((x0, x1, x2, x3, x[:, g * 4:]), dim=1)
    return out
```
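Not part of the original reply, but the speed comparison discussed here can be reproduced with a small self-contained timing sketch. The tensor shape is an assumption chosen to resemble a ShiftViT stage; absolute numbers will vary by hardware and PyTorch build:

```python
import time
import torch

def shift_feat_memcopy(x):
    # Shift four channel groups by one pixel each; vacated positions stay zero.
    B, C, H, W = x.shape
    g = C // 12
    out = torch.zeros_like(x)
    out[:, g * 0:g * 1, :, :-1] = x[:, g * 0:g * 1, :, 1:]
    out[:, g * 1:g * 2, :, 1:] = x[:, g * 1:g * 2, :, :-1]
    out[:, g * 2:g * 3, :-1, :] = x[:, g * 2:g * 3, 1:, :]
    out[:, g * 3:g * 4, 1:, :] = x[:, g * 3:g * 4, :-1, :]
    out[:, g * 4:, :, :] = x[:, g * 4:, :, :]
    return out

def shift_feat_roll(x):
    # Same four shifts via torch.roll; shifted-out pixels wrap around instead.
    B, C, H, W = x.shape
    g = C // 12
    x0 = torch.roll(x[:, g * 0:g * 1], 1, 3)
    x1 = torch.roll(x[:, g * 1:g * 2], -1, 3)
    x2 = torch.roll(x[:, g * 2:g * 3], 1, 2)
    x3 = torch.roll(x[:, g * 3:g * 4], -1, 2)
    return torch.cat((x0, x1, x2, x3, x[:, g * 4:]), dim=1)

if __name__ == "__main__":
    x = torch.randn(8, 96, 56, 56)  # hypothetical stage shape, not from the thread
    for fn in (shift_feat_memcopy, shift_feat_roll):
        fn(x)  # warm-up
        t0 = time.perf_counter()
        for _ in range(20):
            fn(x)
        dt = (time.perf_counter() - t0) / 20
        print(f"{fn.__name__}: {dt * 1e3:.2f} ms/call")
```

Note that the two variants are not numerically identical at the image borders (memcopy zero-fills the vacated row/column, roll wraps it around), so this only measures speed.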
I've run the tests above myself, and it looks like you're correct: it is indeed slower, I was wrong. Closing the issue :)
Hi, very interesting research. I wonder why you implemented the shift_feature as a memory copy (SPACH/models/shiftvit.py, line 107 in 497c1d8) instead of using the `Tensor.roll` operation? It would make your block much faster. Another benefit would be that pixels from one side would "leak" to the other, allowing the network to pass information from one boundary to another, which seems a better option than duplication of the last row during each shift.
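As an aside (not from the original comment), the wrap-around "leak" described here is exactly what `torch.roll` does: values shifted past one edge reappear at the opposite edge, as a tiny example shows:

```python
import torch

row = torch.tensor([[0., 1., 2., 3.]])
print(torch.roll(row, shifts=1, dims=1))   # tensor([[3., 0., 1., 2.]])
print(torch.roll(row, shifts=-1, dims=1))  # tensor([[1., 2., 3., 0.]])
```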