diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py
index 2d83578a..f4e83fa1 100644
--- a/elk/extraction/extraction.py
+++ b/elk/extraction/extraction.py
@@ -23,8 +23,6 @@
 from torch import Tensor
 from transformers import AutoConfig
 
-from ..utils.hf_utils import is_autoregressive
-
 from ..utils import (
     Color,
     assert_type,
@@ -35,6 +33,7 @@
     select_split,
     select_train_val_splits,
 )
+from ..utils.hf_utils import is_autoregressive
 from .dataset_name import (
     DatasetDictWithName,
     parse_dataset_string,
@@ -236,7 +235,7 @@ def tokenize_dataset(
                 variant_id=example["template_names"][i],
                 label=example["label"],
                 text=statement + suffix,
-                input_ids=ids.long()
+                input_ids=ids.long(),
             )
             if cfg.get_lm_preds:
                 out_record["answer_ids"] = answer_ids  # type: ignore
@@ -246,7 +245,7 @@
 
         if any_too_long:
             continue
-        
+
         # print an example text to stdout
         if len(out_records) == 0:
             print(f"Example text: {record_variants[0]['text']}")
@@ -317,9 +316,7 @@
 
     model_config = AutoConfig.from_pretrained(cfg.model)
     if not is_autoregressive(model_config, include_enc_dec=True) and cfg.get_lm_preds:
-        raise ValueError(
-            "Can only extract LM predictions from autoregressive models."
-        )
+        raise ValueError("Can only extract LM predictions from autoregressive models.")
 
     limits = cfg.max_examples
     splits = assert_type(SplitDict, info.splits)
@@ -342,15 +339,10 @@ def extract(
     else:
         print(f"{pretty_name} using '{split_name}' for validation")
 
-
-    def select_hiddens(
-        outputs: Any, **kwargs: Any
-    ) -> tuple[dict[str, Tensor], Tensor]:
+    def select_hiddens(outputs: Any, **kwargs: Any) -> tuple[dict[str, Tensor], Tensor]:
         tok_loc_offset = kwargs.get("num_suffix_tokens", 0)
         # Add one to the number of layers to account for the embedding layer
-        layer_indices = cfg.layers or tuple(
-            range(model_config.num_hidden_layers + 1)
-        )
+        layer_indices = cfg.layers or tuple(range(model_config.num_hidden_layers + 1))
 
         hiddens = outputs.get("decoder_hidden_states") or outputs["hidden_states"]
         # Throw out layers we don't care about
@@ -370,9 +362,7 @@ def select_hiddens(
 
         hidden_dict = dict()
         for layer_idx, hidden in zip(layer_indices, hiddens):
-            hidden_dict[f"hidden_{layer_idx}"] = float_to_int16(
-                hidden.flatten()
-            ).cpu()
+            hidden_dict[f"hidden_{layer_idx}"] = float_to_int16(hidden.flatten()).cpu()
 
         if (answer_ids := kwargs.get("answer_ids")) is not None:
             # log_odds = log(p(yes)/p(no)) = log(p(yes)) - log(p(no))
@@ -384,7 +374,6 @@
 
         return hidden_dict, lm_log_odds
 
-
     def extract_hiddens(
         cfg: Extract,
         split_type: Literal["train", "val"],
@@ -399,7 +388,10 @@
 
         buffer = defaultdict(list)  # row_id -> list of dicts
         for idx, (hidden_dict, lm_log_odds) in server.imap(
-            select_hiddens, encodings, use_tqdm=False, model_kwargs=dict(output_hidden_states=True)
+            select_hiddens,
+            encodings,
+            use_tqdm=False,
+            model_kwargs=dict(output_hidden_states=True),
         ):
             encoding = encodings[idx]
             row_id = encoding["row_id"]
@@ -430,7 +422,7 @@
     def _extraction_worker(**kwargs):
         yield from extract_hiddens(**{k: v[0] for k, v in kwargs.items()})
 
-    # TODO: support int8 
+    # TODO: support int8
     server = InferenceServer(
         model_str=cfg.model, num_workers=num_gpus, cpu_offload=True, fsdp=cfg.fsdp
    )
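An aside on the `# log_odds` comment in `select_hiddens` above: the quantity is just a difference of log-softmax values at the two answer tokens. Below is a minimal sketch of that computation, assuming `logits` is the model's final-token logit vector and `answer_ids` holds the (no, yes) token ids; the function name and variables are illustrative, not the repo's exact code, which additionally applies the `tok_loc_offset` indexing shown in the diff.

import torch

def lm_log_odds(logits: torch.Tensor, answer_ids: torch.Tensor) -> torch.Tensor:
    # log p(yes) - log p(no), computed over the full vocabulary distribution
    log_probs = logits.log_softmax(dim=-1)
    return log_probs[answer_ids[1]] - log_probs[answer_ids[0]]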
diff --git a/elk/extraction/inference_server.py b/elk/extraction/inference_server.py
index 2026f361..9c39ec41 100644
--- a/elk/extraction/inference_server.py
+++ b/elk/extraction/inference_server.py
@@ -27,7 +27,7 @@
 @dataclass(frozen=True)
 class _Sentinel:
     """Sentinel value used to indicate that a worker is done."""
-    
+
 
 SENTINEL = _Sentinel()
 
@@ -131,15 +131,29 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.shutdown()
 
-    def map_forward(self, dataset: Dataset, model_kwargs: dict[str, Any] | None = None, use_tqdm: bool = False) -> list:
+    def map_forward(
+        self,
+        dataset: Dataset,
+        model_kwargs: dict[str, Any] | None = None,
+        use_tqdm: bool = False,
+    ) -> list:
         """Maps the model's `forward` method over the given dataset, without
         running a closure on the outputs."""
-        return self.map(lambda x: x, dataset, model_kwargs=model_kwargs, use_tqdm=use_tqdm)
+        return self.map(
+            lambda x: x, dataset, model_kwargs=model_kwargs, use_tqdm=use_tqdm
+        )
 
-    def imap_forward(self, dataset: Dataset, model_kwargs: dict[str, Any] | None = None, use_tqdm: bool = False) -> Iterable:
+    def imap_forward(
+        self,
+        dataset: Dataset,
+        model_kwargs: dict[str, Any] | None = None,
+        use_tqdm: bool = False,
+    ) -> Iterable:
         """Maps the model's `forward` method over the given dataset, without
         running a closure on the outputs."""
-        yield from self.imap(lambda x: x, dataset, model_kwargs=model_kwargs, use_tqdm=use_tqdm)
+        yield from self.imap(
+            lambda x: x, dataset, model_kwargs=model_kwargs, use_tqdm=use_tqdm
+        )
 
     def map(
         self,
@@ -342,9 +356,7 @@ def maybe_unsqueeze(v):
         outputs = model(**inputs_cuda, **model_kwargs)
 
         if callable(closure):
-            outputs = closure(
-                outputs, **record
-            )
+            outputs = closure(outputs, **record)
         if outputs is not None:
             # Move the outputs back to the CPU
             outputs = pytree_map(lambda x: x.cpu().share_memory_(), outputs)
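Usage note: with the signatures reformatted above, a `forward`-only call site might look like the sketch below. This is a hedged example rather than code from the repo: the model string and toy dataset are placeholders, the constructor arguments mirror the `InferenceServer(...)` call in extraction.py, and the `(idx, outputs)` unpacking assumes `imap_forward` yields indexed results the same way `server.imap` does in `extract_hiddens`.

from datasets import Dataset

from elk.extraction.inference_server import InferenceServer

# Toy single-row dataset; real callers pass tokenized encodings.
encodings = Dataset.from_dict({"input_ids": [[101, 2023, 102]]})

server = InferenceServer(
    model_str="gpt2", num_workers=1, cpu_offload=True, fsdp=False
)
try:
    # imap_forward maps forward() over the dataset without running a closure,
    # yielding raw model outputs as workers finish each row.
    for idx, outputs in server.imap_forward(
        encodings, model_kwargs=dict(output_hidden_states=True)
    ):
        hiddens = outputs["hidden_states"]  # tuple of per-layer tensors
finally:
    server.shutdown()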