Skip to content

Commit

Permalink
fix questio nkv
Browse files Browse the repository at this point in the history
  • Loading branch information
thejaminator committed May 6, 2023
1 parent 03cb047 commit 8542161
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def extract_hiddens(
# Iterate over variants
for i, record in enumerate(example["prompts"]):
variant_questions = []
cached_kv: tuple[tuple[Tensor]] | None = None
cached_question_kv: tuple[tuple[Tensor]] | None = None
# Iterate over answers
for j, choice in enumerate(record):
text = choice["question"]
Expand Down Expand Up @@ -305,13 +305,14 @@ def extract_hiddens(
inputs = dict(input_ids=input_ids_to_pass.long())
if is_enc_dec:
inputs["labels"] = answer
if cached_kv is not None:
inputs["past_key_values"] = cached_kv
inputs["input_ids"] = input_ids
if cached_question_kv is not None:
# If we cached the question, all we need to pass is the answer
inputs["past_key_values"] = cached_question_kv
inputs["input_ids"] = answer

outputs = model(**inputs, output_hidden_states=True, use_cache=True)

cached_kv = get_reusable_kv(
cached_question_kv = get_reusable_kv(
answer_ids=answer, past_key_values=outputs.past_key_values
)
# Compute the log probability of the answer tokens if available
Expand Down

0 comments on commit 8542161

Please sign in to comment.