
Commit

Fix SwinLayer / DonutSwinLayer / ClapAudioLayer attention mask device (huggingface#31295)

Fix DonutSwinLayer attention mask device
gorodnitskiy authored Jun 6, 2024
1 parent b6c9f47 commit 3b4d3d0
Showing 3 changed files with 15 additions and 15 deletions.
10 changes: 5 additions & 5 deletions src/transformers/models/clap/modeling_clap.py
@@ -593,10 +593,10 @@ def set_shift_and_window_size(self, input_resolution):
             self.shift_size = 0
             self.window_size = min(input_resolution)
 
-    def get_attn_mask(self, height, width, dtype):
+    def get_attn_mask(self, height, width, dtype, device):
         if self.shift_size > 0:
             # calculate attention mask for SW-MSA
-            img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
+            img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
             height_slices = (
                 slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
@@ -661,9 +661,9 @@ def forward(
         # partition windows
         hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
         hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
-        attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
-        if attn_mask is not None:
-            attn_mask = attn_mask.to(hidden_states_windows.device)
+        attn_mask = self.get_attn_mask(
+            height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
+        )
 
         attention_outputs = self.attention(
             hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
10 changes: 5 additions & 5 deletions src/transformers/models/donut/modeling_donut_swin.py
@@ -565,10 +565,10 @@ def set_shift_and_window_size(self, input_resolution):
             self.shift_size = 0
             self.window_size = min(input_resolution)
 
-    def get_attn_mask(self, height, width, dtype):
+    def get_attn_mask(self, height, width, dtype, device):
         if self.shift_size > 0:
             # calculate attention mask for SW-MSA
-            img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
+            img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
             height_slices = (
                 slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
@@ -633,9 +633,9 @@ def forward(
         # partition windows
         hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
         hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
-        attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
-        if attn_mask is not None:
-            attn_mask = attn_mask.to(hidden_states_windows.device)
+        attn_mask = self.get_attn_mask(
+            height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
+        )
 
         attention_outputs = self.attention(
             hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
10 changes: 5 additions & 5 deletions src/transformers/models/swin/modeling_swin.py
@@ -642,10 +642,10 @@ def set_shift_and_window_size(self, input_resolution):
             self.shift_size = 0
             self.window_size = min(input_resolution)
 
-    def get_attn_mask(self, height, width, dtype):
+    def get_attn_mask(self, height, width, dtype, device):
         if self.shift_size > 0:
             # calculate attention mask for SW-MSA
-            img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
+            img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
             height_slices = (
                 slice(0, -self.window_size),
                 slice(-self.window_size, -self.shift_size),
@@ -710,9 +710,9 @@ def forward(
         # partition windows
         hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
        hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
-        attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
-        if attn_mask is not None:
-            attn_mask = attn_mask.to(hidden_states_windows.device)
+        attn_mask = self.get_attn_mask(
+            height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
+        )
 
         attention_outputs = self.attention(
             hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
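All three files change the same pattern: instead of allocating the shifted-window (SW-MSA) attention mask on the default device and moving it afterwards with .to(...), the mask is now created directly on the device of the windowed hidden states. Below is a minimal, self-contained sketch of the two patterns; it is illustrative only (the shapes and values are not taken from the commit):

import torch

# Illustrative values; `height`, `width`, and `dtype` are placeholders, not values from the commit.
height, width = 8, 8
dtype = torch.float32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Old pattern: build the mask on the default (CPU) device, then copy it over.
img_mask_old = torch.zeros((1, height, width, 1), dtype=dtype)
img_mask_old = img_mask_old.to(device)

# New pattern (as in the diff): allocate directly on the target device,
# skipping the intermediate CPU tensor and the host-to-device copy.
img_mask_new = torch.zeros((1, height, width, 1), dtype=dtype, device=device)

assert img_mask_old.device == img_mask_new.device

Allocating on the right device up front avoids an extra CPU tensor plus copy per forward pass and keeps the mask on the same device as the hidden states when the model does not live on the default device, which appears to be the mismatch the commit title refers to.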

