Commit 03cb047

try caching kvs

thejaminator committed May 6, 2023
1 parent 55b18ab commit 03cb047

Showing 2 changed files with 51 additions and 9 deletions.
39 changes: 32 additions & 7 deletions elk/extraction/extraction.py
@@ -3,7 +3,7 @@
 import os
 from dataclasses import InitVar, dataclass, replace
 from itertools import zip_longest
-from typing import Any, Iterable, Literal
+from typing import Any, Iterable, Literal, Optional
 from warnings import filterwarnings
 
 import torch
@@ -147,6 +147,24 @@ def explode(self) -> list["Extract"]:
 ]
 
 
+def get_reusable_kv(
+    answer_ids: torch.Tensor, past_key_values: tuple[tuple[torch.Tensor]]
+) -> tuple[tuple[torch.Tensor]]:
+    reshaped_past_key_values = ()
+    answer_length = answer_ids.size(-1)
+
+    for layer_kv in past_key_values:
+        reshaped_kv = ()
+        for t in layer_kv:
+            # Tensor is (batch_size, num_heads, sequence_length, embed_size_per_head)
+            # Reduce the sequence_length by answer_length to remove the answer tokens
+            reshaped_tensor = t[:, :, :-answer_length, :]
+            reshaped_kv += (reshaped_tensor,)
+        reshaped_past_key_values += (reshaped_kv,)
+
+    return reshaped_past_key_values
+
+
 @torch.inference_mode()
 def extract_hiddens(
     cfg: "Extract",
@@ -247,7 +265,7 @@ def extract_hiddens(
         # Iterate over variants
         for i, record in enumerate(example["prompts"]):
             variant_questions = []
-
+            cached_kv: tuple[tuple[Tensor]] | None = None
             # Iterate over answers
             for j, choice in enumerate(record):
                 text = choice["question"]
@@ -264,6 +282,7 @@
                 input_ids = assert_type(Tensor, encoding.input_ids)
                 if is_enc_dec:
                     answer = assert_type(Tensor, encoding.labels)
+                    input_ids_to_pass = input_ids
                 else:
                     encoding2 = tokenizer(
                         choice["answer"],
@@ -272,22 +291,29 @@
                         return_tensors="pt",
                     ).to(first_device)
                     answer = assert_type(Tensor, encoding2.input_ids)
-                    input_ids = torch.cat([input_ids, answer], dim=-1)
+                    input_ids_to_pass = torch.cat([input_ids, answer], dim=-1)
+                    # for decoders, we just need to pass the cached key-values

                 # If this input is too long, skip it
-                if input_ids.shape[-1] > max_length:
+                if input_ids_to_pass.shape[-1] > max_length:
                     break
                 else:
                     # Record the EXACT question we fed to the model
                     variant_questions.append(text)

                 # Make sure we only pass the arguments that the model expects
-                inputs = dict(input_ids=input_ids.long())
+                inputs = dict(input_ids=input_ids_to_pass.long())
                 if is_enc_dec:
                     inputs["labels"] = answer
+                if cached_kv is not None:
+                    inputs["past_key_values"] = cached_kv
+                    inputs["input_ids"] = input_ids

-                outputs = model(**inputs, output_hidden_states=True)
+                outputs = model(**inputs, output_hidden_states=True, use_cache=True)

+                cached_kv = get_reusable_kv(
+                    answer_ids=answer, past_key_values=outputs.past_key_values
+                )
                 # Compute the log probability of the answer tokens if available
                 if has_lm_preds:
                     answer_len = answer.shape[-1]
@@ -320,7 +346,6 @@

                 for layer_idx, hidden in zip(layer_indices, hiddens):
                     hidden_dict[f"hidden_{layer_idx}"][i, j] = float_to_int16(hidden)
-
             # We skipped a pseudolabel because it was too long; break out of this whole
             # example and move on to the next one
             if len(variant_questions) != num_choices:
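The new get_reusable_kv helper slices the last answer_length positions off every layer's key/value cache, so the cache that survives covers only the shared question prefix and can be fed back in when scoring the next answer choice. Below is a minimal, self-contained sketch of that trimming logic using toy tensors; the shapes and names are illustrative assumptions, not code from this commit.

import torch

batch_size, num_heads, seq_len, head_dim = 1, 2, 10, 4

# Hypothetical answer ids: three answer tokens appended to a seven-token question.
answer_ids = torch.tensor([[42, 43, 44]])

# Fake past_key_values: one (key, value) pair per layer, each tensor shaped
# (batch_size, num_heads, sequence_length, embed_size_per_head).
past_key_values = tuple(
    (
        torch.randn(batch_size, num_heads, seq_len, head_dim),
        torch.randn(batch_size, num_heads, seq_len, head_dim),
    )
    for _ in range(2)
)

# Same slicing as get_reusable_kv: drop the trailing answer positions.
answer_length = answer_ids.size(-1)
trimmed = tuple(
    tuple(t[:, :, :-answer_length, :] for t in layer_kv)
    for layer_kv in past_key_values
)

# The surviving cache covers only the question prefix (7 of the 10 positions).
assert trimmed[0][0].shape == (batch_size, num_heads, seq_len - answer_length, head_dim)

One caveat worth noting about stock Hugging Face decoders: when a past_key_values cache of length P is supplied, the accompanying input_ids are treated as new tokens at positions P, P+1, ..., so the usual pattern is to pass only the tokens the cache does not already cover.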
21 changes: 19 additions & 2 deletions tests/test_smoke_eval.py
@@ -2,10 +2,11 @@

 import pandas as pd
 
-from elk import Extract
+from elk import Extract, extract_hiddens
 from elk.evaluation import Eval
 from elk.training import CcsReporterConfig, EigenReporterConfig
 from elk.training.train import Elicit
+from elk.utils.multi_gpu import ModelDevices
 
 EVAL_EXPECTED_FILES = [
     "cfg.yaml",
@@ -19,7 +20,7 @@ def setup_elicit(
     tmp_path: Path,
     dataset_name="imdb",
     model_path="sshleifer/tiny-gpt2",
-    min_mem=10 * 1024 ** 2,
+    min_mem=10 * 1024**2,
     is_ccs: bool = True,
 ) -> Elicit:
     """Setup elicit config for testing, execute elicit, and save output to tmp_path.
@@ -96,6 +97,22 @@ def test_smoke_tfr_eval_run_tiny_gpt2_ccs(tmp_path: Path):
     eval_assert_files_created(elicit, transfer_datasets=transfer_datasets)
 
 
+def test_extract():
+    for i in extract_hiddens(
+        cfg=Extract(
+            model="sshleifer/tiny-gpt2",
+            datasets=("imdb",),
+            max_examples=(10, 10),
+            # run on all layers; tiny-gpt2 only has 2 layers
+        ),
+        devices=ModelDevices(first_device="cpu", other_devices=[]),
+        split_type="train",
+        rank=0,
+        world_size=1,
+    ):
+        print(i)
+
+
 def test_smoke_eval_run_tiny_gpt2_eigen(tmp_path: Path):
     elicit = setup_elicit(tmp_path, is_ccs=False)
     transfer_datasets = ("christykoh/imdb_pt",)
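The new test_extract smoke test drives the cached-KV extraction path end to end on sshleifer/tiny-gpt2. For reference, here is a sketch of the canonical Hugging Face cache-reuse pattern on that same checkpoint, assuming the standard transformers API; it illustrates how past_key_values is normally consumed and is not code from this commit.

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

# Run the shared question once and keep its key/value cache.
question_ids = tok("The movie was", return_tensors="pt").input_ids
out = model(question_ids, use_cache=True)

# Score an answer continuation by passing ONLY the new tokens; the decoder
# places them at the positions right after the cached question prefix.
answer_ids = tok(" great", return_tensors="pt").input_ids
out2 = model(answer_ids, past_key_values=out.past_key_values, use_cache=True)
print(out2.logits.shape)  # (1, num_answer_tokens, vocab_size)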
