From 3b4d3d09fd8cee2b4cc2fdd7c12ea51ca147c6cc Mon Sep 17 00:00:00 2001 From: Alex Gorodnitskiy Date: Thu, 6 Jun 2024 21:52:14 +0100 Subject: [PATCH] Fix SwinLayer / DonutSwinLayer / ClapAudioLayer attention mask device (#31295) Fix DonutSwinLayer attention mask device --- src/transformers/models/clap/modeling_clap.py | 10 +++++----- src/transformers/models/donut/modeling_donut_swin.py | 10 +++++----- src/transformers/models/swin/modeling_swin.py | 10 +++++----- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index d97d36c154badc..1c236d29d4e734 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -593,10 +593,10 @@ def set_shift_and_window_size(self, input_resolution): self.shift_size = 0 self.window_size = min(input_resolution) - def get_attn_mask(self, height, width, dtype): + def get_attn_mask(self, height, width, dtype, device): if self.shift_size > 0: # calculate attention mask for SW-MSA - img_mask = torch.zeros((1, height, width, 1), dtype=dtype) + img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device) height_slices = ( slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), @@ -661,9 +661,9 @@ def forward( # partition windows hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) - attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype) - if attn_mask is not None: - attn_mask = attn_mask.to(hidden_states_windows.device) + attn_mask = self.get_attn_mask( + height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device + ) attention_outputs = self.attention( hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index 4775d00c19e142..7e899f453f1c0f 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -565,10 +565,10 @@ def set_shift_and_window_size(self, input_resolution): self.shift_size = 0 self.window_size = min(input_resolution) - def get_attn_mask(self, height, width, dtype): + def get_attn_mask(self, height, width, dtype, device): if self.shift_size > 0: # calculate attention mask for SW-MSA - img_mask = torch.zeros((1, height, width, 1), dtype=dtype) + img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device) height_slices = ( slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), @@ -633,9 +633,9 @@ def forward( # partition windows hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) - attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype) - if attn_mask is not None: - attn_mask = attn_mask.to(hidden_states_windows.device) + attn_mask = self.get_attn_mask( + height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device + ) attention_outputs = self.attention( hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 13c20f59e99dfa..f3f2dedeb6f3dd 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -642,10 +642,10 @@ def set_shift_and_window_size(self, input_resolution): self.shift_size = 0 self.window_size = min(input_resolution) - def get_attn_mask(self, height, width, dtype): + def get_attn_mask(self, height, width, dtype, device): if self.shift_size > 0: # calculate attention mask for SW-MSA - img_mask = torch.zeros((1, height, width, 1), dtype=dtype) + img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device) height_slices = ( slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), @@ -710,9 +710,9 @@ def forward( # partition windows hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) - attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype) - if attn_mask is not None: - attn_mask = attn_mask.to(hidden_states_windows.device) + attn_mask = self.get_attn_mask( + height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device + ) attention_outputs = self.attention( hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions