From 03cb047cd9fd05502d82186906d7b0c583bc9d4f Mon Sep 17 00:00:00 2001 From: James Chua Date: Sat, 6 May 2023 15:27:02 +0800 Subject: [PATCH] try caching kvs --- elk/extraction/extraction.py | 39 +++++++++++++++++++++++++++++------- tests/test_smoke_eval.py | 21 +++++++++++++++++-- 2 files changed, 51 insertions(+), 9 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index ad040ed6..b569e971 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -3,7 +3,7 @@ import os from dataclasses import InitVar, dataclass, replace from itertools import zip_longest -from typing import Any, Iterable, Literal +from typing import Any, Iterable, Literal, Optional from warnings import filterwarnings import torch @@ -147,6 +147,24 @@ def explode(self) -> list["Extract"]: ] +def get_reusable_kv( + answer_ids: torch.Tensor, past_key_values: tuple[tuple[torch.Tensor]] +) -> tuple[tuple[torch.Tensor]]: + reshaped_past_key_values = () + answer_length = answer_ids.size(-1) + + for layer_kv in past_key_values: + reshaped_kv = () + for t in layer_kv: + # Tensor is (batch_size, num_heads, sequence_length, embed_size_per_head) + # Reduce the sequence_length by answer_length to remove the token + reshaped_tensor = t[:, :, :-answer_length, :] + reshaped_kv += (reshaped_tensor,) + reshaped_past_key_values += (reshaped_kv,) + + return reshaped_past_key_values + + @torch.inference_mode() def extract_hiddens( cfg: "Extract", @@ -247,7 +265,7 @@ def extract_hiddens( # Iterate over variants for i, record in enumerate(example["prompts"]): variant_questions = [] - + cached_kv: tuple[tuple[Tensor]] | None = None # Iterate over answers for j, choice in enumerate(record): text = choice["question"] @@ -264,6 +282,7 @@ def extract_hiddens( input_ids = assert_type(Tensor, encoding.input_ids) if is_enc_dec: answer = assert_type(Tensor, encoding.labels) + input_ids_to_pass = input_ids else: encoding2 = tokenizer( choice["answer"], @@ -272,22 +291,29 @@ def extract_hiddens( return_tensors="pt", ).to(first_device) answer = assert_type(Tensor, encoding2.input_ids) - input_ids = torch.cat([input_ids, answer], dim=-1) + input_ids_to_pass = torch.cat([input_ids, answer], dim=-1) + # for decoders, we just need to pass the cached key-values # If this input is too long, skip it - if input_ids.shape[-1] > max_length: + if input_ids_to_pass.shape[-1] > max_length: break else: # Record the EXACT question we fed to the model variant_questions.append(text) # Make sure we only pass the arguments that the model expects - inputs = dict(input_ids=input_ids.long()) + inputs = dict(input_ids=input_ids_to_pass.long()) if is_enc_dec: inputs["labels"] = answer + if cached_kv is not None: + inputs["past_key_values"] = cached_kv + inputs["input_ids"] = input_ids - outputs = model(**inputs, output_hidden_states=True) + outputs = model(**inputs, output_hidden_states=True, use_cache=True) + cached_kv = get_reusable_kv( + answer_ids=answer, past_key_values=outputs.past_key_values + ) # Compute the log probability of the answer tokens if available if has_lm_preds: answer_len = answer.shape[-1] @@ -320,7 +346,6 @@ def extract_hiddens( for layer_idx, hidden in zip(layer_indices, hiddens): hidden_dict[f"hidden_{layer_idx}"][i, j] = float_to_int16(hidden) - # We skipped a pseudolabel because it was too long; break out of this whole # example and move on to the next one if len(variant_questions) != num_choices: diff --git a/tests/test_smoke_eval.py b/tests/test_smoke_eval.py index 683e718a..ac1f1f1c 100644 --- a/tests/test_smoke_eval.py +++ b/tests/test_smoke_eval.py @@ -2,10 +2,11 @@ import pandas as pd -from elk import Extract +from elk import Extract, extract_hiddens from elk.evaluation import Eval from elk.training import CcsReporterConfig, EigenReporterConfig from elk.training.train import Elicit +from elk.utils.multi_gpu import ModelDevices EVAL_EXPECTED_FILES = [ "cfg.yaml", @@ -19,7 +20,7 @@ def setup_elicit( tmp_path: Path, dataset_name="imdb", model_path="sshleifer/tiny-gpt2", - min_mem=10 * 1024 ** 2, + min_mem=10 * 1024**2, is_ccs: bool = True, ) -> Elicit: """Setup elicit config for testing, execute elicit, and save output to tmp_path. @@ -96,6 +97,22 @@ def test_smoke_tfr_eval_run_tiny_gpt2_ccs(tmp_path: Path): eval_assert_files_created(elicit, transfer_datasets=transfer_datasets) +def test_extract(): + for i in extract_hiddens( + cfg=Extract( + model="sshleifer/tiny-gpt2", + datasets=("imdb",), + max_examples=(10, 10), + # run on all layers, tiny-gpt only has 2 layers + ), + devices=ModelDevices(first_device="cpu", other_devices=[]), + split_type="train", + rank=0, + world_size=1, + ): + print(i) + + def test_smoke_eval_run_tiny_gpt2_eigen(tmp_path: Path): elicit = setup_elicit(tmp_path, is_ccs=False) transfer_datasets = ("christykoh/imdb_pt",)