Skip to content

Commit

Permalink
Fixed bug with encoder only models not expecting "labels" argument in…
Browse files Browse the repository at this point in the history
… forward pass (#264)

* Fixed bug with encoder only models not expecting "labels" argument in forward pass

* Fixed type annotations

* Fixed a bug with decoder-only models
  • Loading branch information
AugustasMacijauskas authored Jun 29, 2023
1 parent ec2b8a0 commit 4ec2289
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,9 @@ def extract_hiddens(
# Record the EXACT question we fed to the model
variant_questions.append(text)

inputs = dict(input_ids=ids.long(), labels=labels)
inputs: dict[str, Tensor | None] = dict(input_ids=ids.long())
if is_enc_dec or has_lm_preds:
inputs["labels"] = labels
outputs = model(**inputs, output_hidden_states=True)

# Compute the log probability of the answer tokens if available
Expand Down

0 comments on commit 4ec2289

Please sign in to comment.