
Shift features implementation #1

Closed
bonlime opened this issue Jan 27, 2022 · 3 comments

bonlime commented Jan 27, 2022

Hi, very interesting research. I wonder why you implemented the feature shift as a memory copy

def shift_feat(x, n_div):

instead of using the Tensor.roll operation? It would make your block much faster. Another benefit is that pixels from one side would wrap around to the other, letting the network pass information from one boundary to the other, which seems a better option than duplicating the last row during each shift.
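For reference, a minimal sketch of the wrap-around behavior `torch.roll` provides (the toy tensor here is a hypothetical example, not from the repository):

```python
import torch

# A 1x1x1x4 "image" with a single row of pixels: [0, 1, 2, 3]
x = torch.arange(4.0).reshape(1, 1, 1, 4)

# Rolling by 1 along the width wraps the last pixel around to the front,
# instead of leaving a zeroed/duplicated border as a memory-copy shift does.
rolled = torch.roll(x, shifts=1, dims=3)
print(rolled)  # tensor([[[[3., 0., 1., 2.]]]])
```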

@YouJiacheng

I think four separate rolls would need a torch.stack or torch.cat (indexing, rolling, then recombining), so it may not be faster.

@wgting96
Contributor

Thank you for the suggestion. We simply adopt the shift implementation from TSM.

I agree that the roll operation enables richer information aggregation. However, in my understanding, the roll operation only supports one shift direction at a time. We need to split the input feature and roll it four times, which makes this operation inefficient. I benchmarked the following code snippet, which shows that the roll operation is about 20% slower than the memory copy.

import torch


def shift_feat_memcopy(x):
    B, C, H, W = x.shape
    g = C // 12  # each of the four shift directions gets 1/12 of the channels
    out = torch.zeros_like(x)

    out[:, g * 0:g * 1, :, :-1] = x[:, g * 0:g * 1, :, 1:]  # shift left
    out[:, g * 1:g * 2, :, 1:] = x[:, g * 1:g * 2, :, :-1]  # shift right
    out[:, g * 2:g * 3, :-1, :] = x[:, g * 2:g * 3, 1:, :]  # shift up
    out[:, g * 3:g * 4, 1:, :] = x[:, g * 3:g * 4, :-1, :]  # shift down

    out[:, g * 4:, :, :] = x[:, g * 4:, :, :]  # no shift
    return out


def shift_feat_roll(x):
    B, C, H, W = x.shape
    g = C // 12
    x0 = torch.roll(x[:, g * 0:g * 1], 1, 3)
    x1 = torch.roll(x[:, g * 1:g * 2], -1, 3)
    x2 = torch.roll(x[:, g * 2:g * 3], 1, 2)
    x3 = torch.roll(x[:, g * 3:g * 4], -1, 2)

    out = torch.cat((x0, x1, x2, x3, x[:, g*4:]), dim=1)
    return out
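A rough sketch of how such a CPU timing comparison might be run is below. The input shape, iteration counts, and `bench` helper are my own assumptions, not the original benchmark; on GPU you would additionally need `torch.cuda.synchronize()` around the timers.

```python
import time

import torch


def shift_feat_memcopy(x):
    _, C, _, _ = x.shape
    g = C // 12
    out = torch.zeros_like(x)
    out[:, 0:g, :, :-1] = x[:, 0:g, :, 1:]          # shift left
    out[:, g:2 * g, :, 1:] = x[:, g:2 * g, :, :-1]  # shift right
    out[:, 2 * g:3 * g, :-1, :] = x[:, 2 * g:3 * g, 1:, :]  # shift up
    out[:, 3 * g:4 * g, 1:, :] = x[:, 3 * g:4 * g, :-1, :]  # shift down
    out[:, 4 * g:] = x[:, 4 * g:]                   # no shift
    return out


def shift_feat_roll(x):
    _, C, _, _ = x.shape
    g = C // 12
    x0 = torch.roll(x[:, 0:g], 1, 3)
    x1 = torch.roll(x[:, g:2 * g], -1, 3)
    x2 = torch.roll(x[:, 2 * g:3 * g], 1, 2)
    x3 = torch.roll(x[:, 3 * g:4 * g], -1, 2)
    return torch.cat((x0, x1, x2, x3, x[:, 4 * g:]), dim=1)


def bench(fn, x, iters=50, warmup=5):
    # Warm up first to exclude one-time allocation costs from the timing.
    for _ in range(warmup):
        fn(x)
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    return (time.perf_counter() - start) / iters


x = torch.randn(2, 48, 32, 32)  # hypothetical shape; C must be divisible by 12
t_mem = bench(shift_feat_memcopy, x)
t_roll = bench(shift_feat_roll, x)
print(f"memcopy: {t_mem * 1e3:.3f} ms/iter, roll: {t_roll * 1e3:.3f} ms/iter")
```

Note that the two functions are not numerically identical at the borders (zero-fill vs wrap-around), so this only compares speed, not output.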


bonlime commented Feb 22, 2022

I've run the tests above myself and it looks like you're correct: it is indeed slower, and I was wrong. Closing the issue :)

@bonlime bonlime closed this as completed Feb 22, 2022