Skip to content

Commit

Permalink
fix: resolve the incompatibility of SwinTransformerV2 with MindSpore 2.0 (ms2.0)
Browse files: browse the repository at this point in the history
  • Loading branch information
The-truthh committed Jul 31, 2023
1 parent 308825d commit 4200e11
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions mindcv/models/swintransformerv2.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -113,34 +113,33 @@ def __init__(
self.cpb_act1 = nn.ReLU()
self.cpb_mlp2 = nn.Dense(512, num_heads, has_bias=False)

relative_coords_h = Tensor(np.arange(-(self.window_size[0] - 1), self.window_size[0]), mstype.float32)
relative_coords_w = Tensor(np.arange(-(self.window_size[1] - 1), self.window_size[1]), mstype.float32)
relative_coords_table = ops.stack(ops.meshgrid((relative_coords_h, relative_coords_w), indexing="ij"), axis=0)
relative_coords_table = relative_coords_table.transpose(1, 2, 0)
relative_coords_table = ops.expand_dims(relative_coords_table, axis=0)
relative_coords_h = np.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=float)
relative_coords_w = np.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=float)
relative_coords_table = np.stack(np.meshgrid(relative_coords_h, relative_coords_w, indexing="ij"), axis=0)
relative_coords_table = np.transpose(relative_coords_table, (1, 2, 0))
relative_coords_table = np.expand_dims(relative_coords_table, axis=0)
if pretrained_window_size[0] > 0:
relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
else:
relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
relative_coords_table *= 8 # normalize to -8, 8
sign = ops.Sign()
relative_coords_table = (
sign(relative_coords_table) * Tensor(np.log2(np.abs(relative_coords_table.asnumpy()) + 1)) / np.log2(8)
np.sign(relative_coords_table) * np.log2(np.abs(relative_coords_table) + 1) / np.log2(8)
)

self.relative_coords_table = Parameter(
Tensor(relative_coords_table, mstype.float32), requires_grad=False
)

# get pair-wise relative position index for each token inside the window
coords_h = Tensor(np.arange(window_size[0]), mstype.int32)
coords_w = Tensor(np.arange(window_size[1]), mstype.int32)
coords = ops.stack(ops.meshgrid((coords_h, coords_w), indexing="ij"), axis=0) # 2, Wh, Ww
coords_flatten = ops.flatten(coords) # 2, Wh*Ww
coords_h = np.arange(window_size[0])
coords_w = np.arange(window_size[1])
coords = np.stack(np.meshgrid(coords_h, coords_w, indexing="ij"), axis=0) # 2, Wh, Ww
coords_flatten = coords.reshape(coords.shape[0], -1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.transpose(1, 2, 0).asnumpy() # Wh*Ww, Wh*Ww, 2
relative_coords = np.transpose(relative_coords, (1, 2, 0)) # Wh*Ww, Wh*Ww, 2

relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
Expand Down

0 comments on commit 4200e11

Please sign in to comment.