Skip to content

Commit

Permalink
print if enc dec
Browse files Browse the repository at this point in the history
  • Loading branch information
thejaminator committed May 6, 2023
1 parent 8542161 commit 71b0850
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def extract_hiddens(
assert hasattr(model, "get_encoder") and callable(model.get_encoder)
model = assert_type(PreTrainedModel, model.get_encoder())
is_enc_dec = False
print("Is the model encoder-decoder:", is_enc_dec)

has_lm_preds = is_autoregressive(model.config, not cfg.use_encoder_states)
if has_lm_preds and rank == 0:
Expand Down Expand Up @@ -312,8 +313,12 @@ def extract_hiddens(

outputs = model(**inputs, output_hidden_states=True, use_cache=True)

cached_question_kv = get_reusable_kv(
answer_ids=answer, past_key_values=outputs.past_key_values
cached_question_kv = (
get_reusable_kv(
answer_ids=answer, past_key_values=outputs.past_key_values
)
if cached_question_kv is None
else cached_question_kv
)
# Compute the log probability of the answer tokens if available
if has_lm_preds:
Expand Down

0 comments on commit 71b0850

Please sign in to comment.