diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 2707c6eeab2..f9bbbca0b38 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -204,7 +204,7 @@ def forward( loss = None if self.use_cache: if past_key_values is not None: - input_ids = input_ids[:, -1:] + input_ids = input_ids[:, -1:] if past_key_values[0][0].shape[2] != 0 else input_ids # Flatten the past_key_values (no need to flatten for models using multi-query attn) if self.config.model_type not in MULTI_QUERY_ATTN_MODELS: past_key_values = tuple(