diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py index ff02502f79a..a345bd18b79 100644 --- a/optimum/onnxruntime/base.py +++ b/optimum/onnxruntime/base.py @@ -534,7 +534,7 @@ def compute_past_key_values_output_shapes( ) -> Dict[str, int]: batch_size = input_ids.size(0) num_attention_heads = self.normalized_config.num_attention_heads - embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads + embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads sequence_length = input_ids.size(1) encoder_sequence_length = encoder_hidden_states.size(1)