diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 609729a4..53d02ee6 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -202,6 +202,7 @@ def extract_hiddens( assert hasattr(model, "get_encoder") and callable(model.get_encoder) model = assert_type(PreTrainedModel, model.get_encoder()) is_enc_dec = False + print("Is the model encoder-decoder:", is_enc_dec) has_lm_preds = is_autoregressive(model.config, not cfg.use_encoder_states) if has_lm_preds and rank == 0: @@ -312,8 +313,12 @@ def extract_hiddens( outputs = model(**inputs, output_hidden_states=True, use_cache=True) - cached_question_kv = get_reusable_kv( - answer_ids=answer, past_key_values=outputs.past_key_values + cached_question_kv = ( + get_reusable_kv( + answer_ids=answer, past_key_values=outputs.past_key_values + ) + if cached_question_kv is None + else cached_question_kv ) # Compute the log probability of the answer tokens if available if has_lm_preds: