From 86e77501b5c3c912f7df350c866d543eda854876 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Sun, 2 Jul 2023 16:12:43 +0100
Subject: [PATCH] Fix

---
 src/brevitas_examples/llm/llm_quant/mha_layers.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/brevitas_examples/llm/llm_quant/mha_layers.py b/src/brevitas_examples/llm/llm_quant/mha_layers.py
index 3630b1e57..51c6d7933 100644
--- a/src/brevitas_examples/llm/llm_quant/mha_layers.py
+++ b/src/brevitas_examples/llm/llm_quant/mha_layers.py
@@ -13,6 +13,8 @@ def attention_mask_handler(
             attention_mask = attention_mask.repeat(batch_size, 1, 1, 1)
         if attention_mask.shape[1] == 1:
             attention_mask = attention_mask.repeat(1, num_heads, 1, 1)
+        if attention_mask.shape[2] == 1:
+            attention_mask = attention_mask.repeat(1, 1, query_seq_length, 1)
         attention_mask = attention_mask.view(
             batch_size * num_heads, query_seq_length, key_value_seq_length)
     elif len(attention_mask.shape) == 2 and attention_mask.shape[0] == 1: