diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index fbb93763..d4e6f4ea 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -283,7 +283,9 @@ def extract_hiddens( # Record the EXACT question we fed to the model variant_questions.append(text) - inputs = dict(input_ids=ids.long(), labels=labels) + inputs: dict[str, Tensor | None] = dict(input_ids=ids.long()) + if is_enc_dec or has_lm_preds: + inputs["labels"] = labels outputs = model(**inputs, output_hidden_states=True) # Compute the log probability of the answer tokens if available