From 3bc362ece0dcfd50ac246e4ad1213d0e328e109f Mon Sep 17 00:00:00 2001
From: Alex Mallen
Date: Wed, 20 Sep 2023 21:03:56 +0000
Subject: [PATCH] don't load model when using cache

---
 elk/extraction/extraction.py | 243 ++++++++++++++++++-----------------
 elk/extraction/generator.py  |   2 +-
 2 files changed, 128 insertions(+), 117 deletions(-)

diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py
index 4bc3f5c3..f25d26af 100644
--- a/elk/extraction/extraction.py
+++ b/elk/extraction/extraction.py
@@ -1,12 +1,9 @@
 """Functions for extracting the hidden states of a model."""
-import logging
 import os
 from collections import defaultdict
-from contextlib import nullcontext, redirect_stdout
 from dataclasses import InitVar, dataclass, replace
 from itertools import zip_longest
 from typing import Any, Iterable, Literal
-from warnings import filterwarnings
 
 import torch
 from datasets import (
@@ -42,8 +39,9 @@
     parse_dataset_string,
 )
 from .generator import _GeneratorBuilder
-from .prompt_loading import load_prompts
 from .inference_server import InferenceServer
+from .prompt_loading import load_prompts
+
 
 @dataclass
 class Extract(Serializable):
@@ -159,9 +157,9 @@ def get_encodings(
     # welcome message on every rank
     tokenizer = instantiate_tokenizer(cfg.model, truncation_side="left")
 
-    model_config = AutoConfig.from_pretrained(cfg.model)
+    AutoConfig.from_pretrained(cfg.model)
     # TODO: support using the encoder only of an encoder-decoder model
-
+
     prompt_ds = load_prompts(
         ds_names[0],
         num_shots=cfg.num_shots,
@@ -206,7 +204,7 @@ def get_encodings(
             inputs: dict[str, Tensor | None | bool] = dict(input_ids=ids.long())
             inputs["output_hidden_states"] = True
-
+
             out_record: dict[str, Any] = dict(
                 row_id=example["row_id"],
                 variant_id=example["template_names"][i],
@@ -216,11 +214,11 @@ def get_encodings(
             )
 
             record_variants.append(out_record)
-
+
         if any_too_long:
             continue
         out_records.extend(record_variants)
-
+
     # transpose the list of dicts into a dict of lists
     out_records = {k: [d[k] for d in out_records] for k in out_records[0]}
     return Dataset.from_dict(out_records)
@@ -279,120 +277,133 @@ def extract(
     split_type: Literal["train", "val", None] = None,
 ) -> DatasetDictWithName:
     """Extract hidden states from a model and return a `DatasetDict` containing them."""
-    with InferenceServer(model_str=cfg.model, num_workers=num_gpus, cpu_offload=True, fsdp=cfg.fsdp) as server:
-        print(f"Using {server.num_workers} workers for inference")
-        info, features = hidden_features(cfg)
+    info, features = hidden_features(cfg)
 
-        model_config = AutoConfig.from_pretrained(cfg.model)
-        limits = cfg.max_examples
-        splits = assert_type(SplitDict, info.splits)
+    model_config = AutoConfig.from_pretrained(cfg.model)
+    limits = cfg.max_examples
+    splits = assert_type(SplitDict, info.splits)
 
-        pretty_name = colorize(assert_type(str, cfg.datasets[0]), highlight_color)
-        if split_type is None:
-            train, val = select_train_val_splits(splits)
+    pretty_name = colorize(assert_type(str, cfg.datasets[0]), highlight_color)
+    if split_type is None:
+        train, val = select_train_val_splits(splits)
 
-            print(f"{pretty_name} using '{train}' for training and '{val}' for validation")
-            splits = SplitDict({train: splits[train], val: splits[val]})
-            split_types = ["train", "val"]
+        print(f"{pretty_name} using '{train}' for training and '{val}' for validation")
+        splits = SplitDict({train: splits[train], val: splits[val]})
+        split_types = ["train", "val"]
+    else:
+        # Remove the split we're not using
+        limits = [limits[0]] if split_type == "train" else limits
+        split_name = select_split(splits, split_type)
+        splits = SplitDict({split_name: splits[split_name]})
+        split_types = [split_type]
+
+        if split_type == "train":
+            print(f"{pretty_name} using '{split_name}' for training")
         else:
-            # Remove the split we're not using
-            limits = [limits[0]] if split_type == "train" else limits
-            split_name = select_split(splits, split_type)
-            splits = SplitDict({split_name: splits[split_name]})
-            split_types = [split_type]
-
-            if split_type == "train":
-                print(f"{pretty_name} using '{split_name}' for training")
-            else:
-                print(f"{pretty_name} using '{split_name}' for validation")
-
-        # define _extraction_worker in this context to yield modified outputs from server.imap
-        def extract_hiddens(
-            cfg: Extract,
-            split_type: Literal["train", "val"],
-        ) -> Iterable[dict]:
-
-            encodings = get_encodings(cfg, split_type=split_type)
-            num_variants = len(encodings.unique("variant_id"))
-            def select_hiddens(outputs: Any) -> dict:
-                # Add one to the number of layers to account for the embedding layer
-                layer_indices = cfg.layers or tuple(range(model_config.num_hidden_layers + 1))
-
-                hiddens = outputs.get("decoder_hidden_states") or outputs["hidden_states"]
-                # Throw out layers we don't care about
-                hiddens = [hiddens[i] for i in layer_indices]
-
-                # Current shape of each element: (batch_size, seq_len, hidden_size)
-                if cfg.token_loc == "first":
-                    hiddens = [h[..., 0, :] for h in hiddens]
-                elif cfg.token_loc == "last":
-                    hiddens = [h[..., -1, :] for h in hiddens]
-                elif cfg.token_loc == "penultimate":
-                    hiddens = [h[..., -2, :] for h in hiddens]
-                elif cfg.token_loc == "mean":
-                    hiddens = [h.mean(dim=-2) for h in hiddens]
-                else:
-                    raise ValueError(f"Invalid token_loc: {cfg.token_loc}")
-
-                hidden_dict = dict()
-                for layer_idx, hidden in zip(layer_indices, hiddens):
-                    hidden_dict[f"hidden_{layer_idx}"] = float_to_int16(hidden.flatten()).cpu()
-
-                return hidden_dict
-
-            encodings = encodings.add_column("id", range(len(encodings)))  # type: ignore[attr-defined]
-            buffer = defaultdict(list)  # row_id -> list of dicts
-            for idx, hidden_dict in server.imap(select_hiddens, encodings):
-                encoding = encodings[idx]
-                row_id = encoding["row_id"]
-                buffer[row_id].append(dict(**encoding, **hidden_dict))
-                if len(buffer[row_id]) == num_variants:
-                    # we have a complete example
-                    ex = buffer[row_id]
-                    assert all(d["label"] == ex[0]["label"] for d in ex)
-                    assert len(set(d["variant_id"] for d in ex)) == num_variants
-                    out_record = dict(
-                        variant_ids=[d["variant_id"] for d in ex],
-                        label=ex[0]["label"],
-                        row_id=ex[0]["row_id"],
-                        texts=[d["text"] for d in ex],
-                        **{k: torch.stack([d[k] for d in ex]) for k in hidden_dict},
-                    )
-                    del buffer[row_id]
-                    yield out_record
-
-        def _extraction_worker(**kwargs):
-            yield from extract_hiddens(**{k: v[0] for k, v in kwargs.items()})
-
-        builders = {
-            split_name: _GeneratorBuilder(
-                cache_dir=None,
-                features=features,
-                generator=_extraction_worker,
-                split_name=split_name,
-                split_info=SplitInfo(
-                    name=split_name,
-                    num_examples=min(limit, v.num_examples) * len(cfg.datasets),
-                    dataset_name=v.dataset_name,
-                ),
-                gen_kwargs=dict(
-                    cfg=[cfg],
-                    split_type=[ty],
-                ),
-            )
-            for limit, (split_name, v), ty in zip(limits, splits.items(), split_types)
-        }
-
-        ds = dict()
-        for split, builder in builders.items():
-            builder.download_and_prepare(
-                download_mode=DownloadMode.FORCE_REDOWNLOAD if disable_cache else None,
-                num_proc=None,
-            )
-            ds[split] = builder.as_dataset(split=split)  # type: ignore[assignment]
-        dataset_dict = DatasetDict(ds)
+            print(f"{pretty_name} using '{split_name}' for validation")
+
+    server = InferenceServer(
+        model_str=cfg.model, num_workers=num_gpus, cpu_offload=True, fsdp=cfg.fsdp
+    )
+
+    def extract_hiddens(
+        cfg: Extract,
+        split_type: Literal["train", "val"],
+        server: InferenceServer,
+    ) -> Iterable[dict]:
+        encodings = get_encodings(cfg, split_type=split_type)
+        num_variants = len(encodings.unique("variant_id"))
+
+        def select_hiddens(outputs: Any) -> dict:
+            # Add one to the number of layers to account for the embedding layer
+            layer_indices = cfg.layers or tuple(
+                range(model_config.num_hidden_layers + 1)
+            )
+
+            hiddens = outputs.get("decoder_hidden_states") or outputs["hidden_states"]
+            # Throw out layers we don't care about
+            hiddens = [hiddens[i] for i in layer_indices]
+
+            # Current shape of each element: (batch_size, seq_len, hidden_size)
+            if cfg.token_loc == "first":
+                hiddens = [h[..., 0, :] for h in hiddens]
+            elif cfg.token_loc == "last":
+                hiddens = [h[..., -1, :] for h in hiddens]
+            elif cfg.token_loc == "penultimate":
+                hiddens = [h[..., -2, :] for h in hiddens]
+            elif cfg.token_loc == "mean":
+                hiddens = [h.mean(dim=-2) for h in hiddens]
+            else:
+                raise ValueError(f"Invalid token_loc: {cfg.token_loc}")
+
+            hidden_dict = dict()
+            for layer_idx, hidden in zip(layer_indices, hiddens):
+                hidden_dict[f"hidden_{layer_idx}"] = float_to_int16(
+                    hidden.flatten()
+                ).cpu()
+
+            return hidden_dict
+
+        if not server.running:
+            server.start()
+        encodings = encodings.add_column("id", range(len(encodings)))  # type: ignore
+        buffer = defaultdict(list)  # row_id -> list of dicts
+        for idx, hidden_dict in server.imap(select_hiddens, encodings, use_tqdm=False):
+            encoding = encodings[idx]
+            row_id = encoding["row_id"]
+            buffer[row_id].append(dict(**encoding, **hidden_dict))
+            if len(buffer[row_id]) == num_variants:
+                # we have a complete example
+                ex = buffer[row_id]
+                assert all(d["label"] == ex[0]["label"] for d in ex)
+                assert len(set(d["variant_id"] for d in ex)) == num_variants
+                out_record = dict(
+                    variant_ids=[d["variant_id"] for d in ex],
+                    label=ex[0]["label"],
+                    row_id=ex[0]["row_id"],
+                    texts=[d["text"] for d in ex],
+                    **{k: torch.stack([d[k] for d in ex]) for k in hidden_dict},
+                )
+                del buffer[row_id]
+                yield out_record
+
+    # hf wraps everything in a list here, so we unpack them here
+    def _extraction_worker(**kwargs):
+        yield from extract_hiddens(**{k: v[0] for k, v in kwargs.items()})
+
+    builders = {
+        split_name: _GeneratorBuilder(
+            cache_dir=None,
+            features=features,
+            generator=_extraction_worker,
+            split_name=split_name,
+            split_info=SplitInfo(
+                name=split_name,
+                num_examples=min(limit, v.num_examples) * len(cfg.datasets),
+                dataset_name=v.dataset_name,
+            ),
+            gen_kwargs=dict(
+                cfg=[cfg],
+                split_type=[ty],
+                server=[server],
+            ),
+        )
+        for limit, (split_name, v), ty in zip(limits, splits.items(), split_types)
+    }
+
+    ds = dict()
+    for split, builder in builders.items():
+        builder.download_and_prepare(
+            download_mode=DownloadMode.FORCE_REDOWNLOAD if disable_cache else None,
+            num_proc=None,
+        )
+        ds[split] = builder.as_dataset(split=split)  # type: ignore[assignment]
+
+    if server.running:
+        server.shutdown()
+
+    dataset_dict = DatasetDict(ds)
 
     return DatasetDictWithName(
         name=cfg.datasets[0],
diff --git a/elk/extraction/generator.py b/elk/extraction/generator.py
index 686804f9..9e6e2777 100644
--- a/elk/extraction/generator.py
+++ b/elk/extraction/generator.py
@@ -30,7 +30,7 @@ def create_config_id(
         config_kwargs["gen_kwargs"] = {
             k: v[0]
             for k, v in config_kwargs.get("gen_kwargs", {}).items()
-            if k not in ("device", "rank", "world_size")
+            if k not in ("device", "rank", "world_size", "server")
         }
         config_kwargs.pop("generator")  # pickling InferenceServer fails
        return super().create_config_id(config_kwargs, custom_features)