don't load model when using cache
AlexTMallen committed Sep 20, 2023
1 parent 1a04ee8 commit 3bc362e
Showing 2 changed files with 128 additions and 117 deletions.
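The fix, in brief: `extract` previously entered a `with InferenceServer(...)` block before Hugging Face `datasets` ever consulted its cache, so the model was loaded even when every split could be served from disk. After this commit the server is constructed (but not started) up front, passed to the generator through `gen_kwargs`, and only started from inside the generator body; on a cache hit the generator never runs, so the model is never loaded. A minimal sketch of the pattern, with placeholder inputs (`"gpt2"`, `encodings`) and the `InferenceServer` API as it appears in the diff below:

```python
from elk.extraction.inference_server import InferenceServer

# Constructing the server is cheap; start() is what loads the model.
server = InferenceServer(
    model_str="gpt2", num_workers=1, cpu_offload=True, fsdp=False
)

def keep_outputs(outputs):
    # Stand-in for select_hiddens: the real code extracts hidden states here.
    return outputs

def rows(encodings):
    # This body only runs on a datasets cache miss.
    if not server.running:
        server.start()
    for idx, outputs in server.imap(keep_outputs, encodings, use_tqdm=False):
        yield {"id": idx, "outputs": outputs}
```

On a miss, the caller still shuts the workers down once the builders finish (`if server.running: server.shutdown()`), exactly as the new `extract` does.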
243 changes: 127 additions & 116 deletions elk/extraction/extraction.py
@@ -1,12 +1,9 @@
"""Functions for extracting the hidden states of a model."""
import logging
import os
from collections import defaultdict
from contextlib import nullcontext, redirect_stdout
from dataclasses import InitVar, dataclass, replace
from itertools import zip_longest
from typing import Any, Iterable, Literal
from warnings import filterwarnings

import torch
from datasets import (
@@ -42,8 +39,9 @@
parse_dataset_string,
)
from .generator import _GeneratorBuilder
from .prompt_loading import load_prompts
from .inference_server import InferenceServer
from .prompt_loading import load_prompts


@dataclass
class Extract(Serializable):
@@ -159,9 +157,9 @@ def get_encodings(
# welcome message on every rank
tokenizer = instantiate_tokenizer(cfg.model, truncation_side="left")

model_config = AutoConfig.from_pretrained(cfg.model)
AutoConfig.from_pretrained(cfg.model)
# TODO: support using the encoder only of an encoder-decoder model

prompt_ds = load_prompts(
ds_names[0],
num_shots=cfg.num_shots,
@@ -206,7 +204,7 @@ def get_encodings(

inputs: dict[str, Tensor | None | bool] = dict(input_ids=ids.long())
inputs["output_hidden_states"] = True

out_record: dict[str, Any] = dict(
row_id=example["row_id"],
variant_id=example["template_names"][i],
@@ -216,11 +214,11 @@ def get_encodings(
)

record_variants.append(out_record)

if any_too_long:
continue
out_records.extend(record_variants)

# transpose the list of dicts into a dict of lists
out_records = {k: [d[k] for d in out_records] for k in out_records[0]}
return Dataset.from_dict(out_records)
@@ -279,120 +277,133 @@ def extract(
split_type: Literal["train", "val", None] = None,
) -> DatasetDictWithName:
"""Extract hidden states from a model and return a `DatasetDict` containing them."""
with InferenceServer(model_str=cfg.model, num_workers=num_gpus, cpu_offload=True, fsdp=cfg.fsdp) as server:
print(f"Using {server.num_workers} workers for inference")

info, features = hidden_features(cfg)
info, features = hidden_features(cfg)

model_config = AutoConfig.from_pretrained(cfg.model)
limits = cfg.max_examples
splits = assert_type(SplitDict, info.splits)
model_config = AutoConfig.from_pretrained(cfg.model)
limits = cfg.max_examples
splits = assert_type(SplitDict, info.splits)

pretty_name = colorize(assert_type(str, cfg.datasets[0]), highlight_color)
if split_type is None:
train, val = select_train_val_splits(splits)
pretty_name = colorize(assert_type(str, cfg.datasets[0]), highlight_color)
if split_type is None:
train, val = select_train_val_splits(splits)

print(f"{pretty_name} using '{train}' for training and '{val}' for validation")
splits = SplitDict({train: splits[train], val: splits[val]})
split_types = ["train", "val"]
print(f"{pretty_name} using '{train}' for training and '{val}' for validation")
splits = SplitDict({train: splits[train], val: splits[val]})
split_types = ["train", "val"]
else:
# Remove the split we're not using
limits = [limits[0]] if split_type == "train" else limits
split_name = select_split(splits, split_type)
splits = SplitDict({split_name: splits[split_name]})
split_types = [split_type]

if split_type == "train":
print(f"{pretty_name} using '{split_name}' for training")
else:
# Remove the split we're not using
limits = [limits[0]] if split_type == "train" else limits
split_name = select_split(splits, split_type)
splits = SplitDict({split_name: splits[split_name]})
split_types = [split_type]

if split_type == "train":
print(f"{pretty_name} using '{split_name}' for training")
else:
print(f"{pretty_name} using '{split_name}' for validation")

# define _extraction_worker in this context to yield modified outputs from server.imap
def extract_hiddens(
cfg: Extract,
split_type: Literal["train", "val"],
) -> Iterable[dict]:

encodings = get_encodings(cfg, split_type=split_type)
num_variants = len(encodings.unique("variant_id"))
def select_hiddens(outputs: Any) -> dict:
# Add one to the number of layers to account for the embedding layer
layer_indices = cfg.layers or tuple(range(model_config.num_hidden_layers + 1))

hiddens = outputs.get("decoder_hidden_states") or outputs["hidden_states"]
# Throw out layers we don't care about
hiddens = [hiddens[i] for i in layer_indices]

# Current shape of each element: (batch_size, seq_len, hidden_size)
if cfg.token_loc == "first":
hiddens = [h[..., 0, :] for h in hiddens]
elif cfg.token_loc == "last":
hiddens = [h[..., -1, :] for h in hiddens]
elif cfg.token_loc == "penultimate":
hiddens = [h[..., -2, :] for h in hiddens]
elif cfg.token_loc == "mean":
hiddens = [h.mean(dim=-2) for h in hiddens]
else:
raise ValueError(f"Invalid token_loc: {cfg.token_loc}")

hidden_dict = dict()
for layer_idx, hidden in zip(layer_indices, hiddens):
hidden_dict[f"hidden_{layer_idx}"] = float_to_int16(hidden.flatten()).cpu()

return hidden_dict

encodings = encodings.add_column("id", range(len(encodings))) # type: ignore[attr-defined]
buffer = defaultdict(list) # row_id -> list of dicts
for idx, hidden_dict in server.imap(select_hiddens, encodings):
encoding = encodings[idx]
row_id = encoding["row_id"]
buffer[row_id].append(dict(**encoding, **hidden_dict))
if len(buffer[row_id]) == num_variants:
# we have a complete example
ex = buffer[row_id]
assert all(d["label"] == ex[0]["label"] for d in ex)
assert len(set(d["variant_id"] for d in ex)) == num_variants
out_record = dict(
variant_ids=[d["variant_id"] for d in ex],
label=ex[0]["label"],
row_id=ex[0]["row_id"],
texts=[d["text"] for d in ex],
**{k: torch.stack([d[k] for d in ex]) for k in hidden_dict},
)
del buffer[row_id]
yield out_record

def _extraction_worker(**kwargs):
yield from extract_hiddens(**{k: v[0] for k, v in kwargs.items()})

builders = {
split_name: _GeneratorBuilder(
cache_dir=None,
features=features,
generator=_extraction_worker,
split_name=split_name,
split_info=SplitInfo(
name=split_name,
num_examples=min(limit, v.num_examples) * len(cfg.datasets),
dataset_name=v.dataset_name,
),
gen_kwargs=dict(
cfg=[cfg],
split_type=[ty],
),
)
for limit, (split_name, v), ty in zip(limits, splits.items(), split_types)
}

ds = dict()
for split, builder in builders.items():
builder.download_and_prepare(
download_mode=DownloadMode.FORCE_REDOWNLOAD if disable_cache else None,
num_proc=None,
print(f"{pretty_name} using '{split_name}' for validation")

server = InferenceServer(
model_str=cfg.model, num_workers=num_gpus, cpu_offload=True, fsdp=cfg.fsdp
)

def extract_hiddens(
cfg: Extract,
split_type: Literal["train", "val"],
server: InferenceServer,
) -> Iterable[dict]:
encodings = get_encodings(cfg, split_type=split_type)
num_variants = len(encodings.unique("variant_id"))

def select_hiddens(outputs: Any) -> dict:
# Add one to the number of layers to account for the embedding layer
layer_indices = cfg.layers or tuple(
range(model_config.num_hidden_layers + 1)
)
ds[split] = builder.as_dataset(split=split) # type: ignore[assignment]

dataset_dict = DatasetDict(ds)
hiddens = outputs.get("decoder_hidden_states") or outputs["hidden_states"]
# Throw out layers we don't care about
hiddens = [hiddens[i] for i in layer_indices]

# Current shape of each element: (batch_size, seq_len, hidden_size)
if cfg.token_loc == "first":
hiddens = [h[..., 0, :] for h in hiddens]
elif cfg.token_loc == "last":
hiddens = [h[..., -1, :] for h in hiddens]
elif cfg.token_loc == "penultimate":
hiddens = [h[..., -2, :] for h in hiddens]
elif cfg.token_loc == "mean":
hiddens = [h.mean(dim=-2) for h in hiddens]
else:
raise ValueError(f"Invalid token_loc: {cfg.token_loc}")

hidden_dict = dict()
for layer_idx, hidden in zip(layer_indices, hiddens):
hidden_dict[f"hidden_{layer_idx}"] = float_to_int16(
hidden.flatten()
).cpu()

return hidden_dict

if not server.running:
server.start()
encodings = encodings.add_column("id", range(len(encodings))) # type: ignore
buffer = defaultdict(list) # row_id -> list of dicts
for idx, hidden_dict in server.imap(select_hiddens, encodings, use_tqdm=False):
encoding = encodings[idx]
row_id = encoding["row_id"]
buffer[row_id].append(dict(**encoding, **hidden_dict))
if len(buffer[row_id]) == num_variants:
# we have a complete example
ex = buffer[row_id]
assert all(d["label"] == ex[0]["label"] for d in ex)
assert len(set(d["variant_id"] for d in ex)) == num_variants
out_record = dict(
variant_ids=[d["variant_id"] for d in ex],
label=ex[0]["label"],
row_id=ex[0]["row_id"],
texts=[d["text"] for d in ex],
**{k: torch.stack([d[k] for d in ex]) for k in hidden_dict},
)
del buffer[row_id]
yield out_record

# hf wraps everything in a list here, so we unpack them here
def _extraction_worker(**kwargs):
yield from extract_hiddens(**{k: v[0] for k, v in kwargs.items()})

builders = {
split_name: _GeneratorBuilder(
cache_dir=None,
features=features,
generator=_extraction_worker,
split_name=split_name,
split_info=SplitInfo(
name=split_name,
num_examples=min(limit, v.num_examples) * len(cfg.datasets),
dataset_name=v.dataset_name,
),
gen_kwargs=dict(
cfg=[cfg],
split_type=[ty],
server=[server],
),
)
for limit, (split_name, v), ty in zip(limits, splits.items(), split_types)
}

ds = dict()
for split, builder in builders.items():
builder.download_and_prepare(
download_mode=DownloadMode.FORCE_REDOWNLOAD if disable_cache else None,
num_proc=None,
)
ds[split] = builder.as_dataset(split=split) # type: ignore[assignment]

if server.running:
server.shutdown()

dataset_dict = DatasetDict(ds)

return DatasetDictWithName(
name=cfg.datasets[0],
2 changes: 1 addition & 1 deletion elk/extraction/generator.py
@@ -30,7 +30,7 @@ def create_config_id(
config_kwargs["gen_kwargs"] = {
k: v[0]
for k, v in config_kwargs.get("gen_kwargs", {}).items()
if k not in ("device", "rank", "world_size")
if k not in ("device", "rank", "world_size", "server")
}
config_kwargs.pop("generator") # pickling InferenceServer fails
return super().create_config_id(config_kwargs, custom_features)
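For the cache to hit at all, the config id that `datasets` derives from `gen_kwargs` must stay stable across runs, and the new `server` argument cannot be part of it: pickling `InferenceServer` fails (per the comment above), and a fresh per-run object would change the fingerprint on every invocation. A sketch of the idea, assuming `datasets`' `Hasher` helper (the real `create_config_id` delegates to the superclass instead):

```python
from datasets.fingerprint import Hasher

def stable_config_id(gen_kwargs: dict) -> str:
    # Drop per-run and unpicklable objects before fingerprinting;
    # the excluded keys mirror the diff above.
    hashable = {
        k: v
        for k, v in gen_kwargs.items()
        if k not in ("device", "rank", "world_size", "server")
    }
    return Hasher.hash(hashable)
```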
