
Commit

cleanup
AlexTMallen committed Sep 29, 2023
1 parent ebf95b0 commit 22a804b
Showing 2 changed files with 32 additions and 28 deletions.
32 changes: 12 additions & 20 deletions elk/extraction/extraction.py
@@ -23,8 +23,6 @@
from torch import Tensor
from transformers import AutoConfig

from ..utils.hf_utils import is_autoregressive

from ..utils import (
Color,
assert_type,
@@ -35,6 +33,7 @@
select_split,
select_train_val_splits,
)
from ..utils.hf_utils import is_autoregressive
from .dataset_name import (
DatasetDictWithName,
parse_dataset_string,
@@ -236,7 +235,7 @@ def tokenize_dataset(
variant_id=example["template_names"][i],
label=example["label"],
text=statement + suffix,
input_ids=ids.long()
input_ids=ids.long(),
)
if cfg.get_lm_preds:
out_record["answer_ids"] = answer_ids # type: ignore
@@ -246,7 +245,7 @@ def tokenize_dataset(

if any_too_long:
continue

# print an example text to stdout
if len(out_records) == 0:
print(f"Example text: {record_variants[0]['text']}")
@@ -317,9 +316,7 @@ def extract(

model_config = AutoConfig.from_pretrained(cfg.model)
if not is_autoregressive(model_config, include_enc_dec=True) and cfg.get_lm_preds:
raise ValueError(
"Can only extract LM predictions from autoregressive models."
)
raise ValueError("Can only extract LM predictions from autoregressive models.")
limits = cfg.max_examples
splits = assert_type(SplitDict, info.splits)

@@ -342,15 +339,10 @@ def extract(
else:
print(f"{pretty_name} using '{split_name}' for validation")


def select_hiddens(
outputs: Any, **kwargs: Any
) -> tuple[dict[str, Tensor], Tensor]:
def select_hiddens(outputs: Any, **kwargs: Any) -> tuple[dict[str, Tensor], Tensor]:
tok_loc_offset = kwargs.get("num_suffix_tokens", 0)
# Add one to the number of layers to account for the embedding layer
layer_indices = cfg.layers or tuple(
range(model_config.num_hidden_layers + 1)
)
layer_indices = cfg.layers or tuple(range(model_config.num_hidden_layers + 1))

hiddens = outputs.get("decoder_hidden_states") or outputs["hidden_states"]
# Throw out layers we don't care about
@@ -370,9 +362,7 @@ def select_hiddens(

hidden_dict = dict()
for layer_idx, hidden in zip(layer_indices, hiddens):
hidden_dict[f"hidden_{layer_idx}"] = float_to_int16(
hidden.flatten()
).cpu()
hidden_dict[f"hidden_{layer_idx}"] = float_to_int16(hidden.flatten()).cpu()

if (answer_ids := kwargs.get("answer_ids")) is not None:
# log_odds = log(p(yes)/p(no)) = log(p(yes)) - log(p(no))
@@ -384,7 +374,6 @@ def select_hiddens(

return hidden_dict, lm_log_odds


def extract_hiddens(
cfg: Extract,
split_type: Literal["train", "val"],
@@ -399,7 +388,10 @@ def extract_hiddens(

buffer = defaultdict(list) # row_id -> list of dicts
for idx, (hidden_dict, lm_log_odds) in server.imap(
select_hiddens, encodings, use_tqdm=False, model_kwargs=dict(output_hidden_states=True)
select_hiddens,
encodings,
use_tqdm=False,
model_kwargs=dict(output_hidden_states=True),
):
encoding = encodings[idx]
row_id = encoding["row_id"]
@@ -430,7 +422,7 @@ def extract_hiddens(
def _extraction_worker(**kwargs):
yield from extract_hiddens(**{k: v[0] for k, v in kwargs.items()})

# TODO: support int8
# TODO: support int8
server = InferenceServer(
model_str=cfg.model, num_workers=num_gpus, cpu_offload=True, fsdp=cfg.fsdp
)
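Note: the `answer_ids` branch in `select_hiddens` above records the LM prediction as a log-odds score, log p(yes) - log p(no), taken from the final-token logits at the two answer token ids. A minimal standalone sketch of that arithmetic follows; the tensor values and token ids are made up for illustration and are not taken from the repository.

```python
import torch

# Hypothetical stand-ins: `logits` is the model's final-position logits,
# `answer_ids` holds the token ids for the "no" and "yes" answers.
logits = torch.randn(50_257)            # [vocab_size]
answer_ids = torch.tensor([645, 3763])  # made-up ids for " no" / " yes"

log_probs = torch.log_softmax(logits, dim=-1)
# log_odds = log p(yes) - log p(no)
log_odds = log_probs[answer_ids[1]] - log_probs[answer_ids[0]]
print(float(log_odds))
```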
28 changes: 20 additions & 8 deletions elk/extraction/inference_server.py
@@ -27,7 +27,7 @@
@dataclass(frozen=True)
class _Sentinel:
"""Sentinel value used to indicate that a worker is done."""


SENTINEL = _Sentinel()

@@ -131,15 +131,29 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
self.shutdown()

def map_forward(self, dataset: Dataset, model_kwargs: dict[str, Any] | None = None, use_tqdm: bool = False) -> list:
def map_forward(
self,
dataset: Dataset,
model_kwargs: dict[str, Any] | None = None,
use_tqdm: bool = False,
) -> list:
"""Maps the model's `forward` method over the given dataset, without
running a closure on the outputs."""
return self.map(lambda x: x, dataset, model_kwargs=model_kwargs, use_tqdm=use_tqdm)
return self.map(
lambda x: x, dataset, model_kwargs=model_kwargs, use_tqdm=use_tqdm
)

def imap_forward(self, dataset: Dataset, model_kwargs: dict[str, Any] | None = None, use_tqdm: bool = False) -> Iterable:
def imap_forward(
self,
dataset: Dataset,
model_kwargs: dict[str, Any] | None = None,
use_tqdm: bool = False,
) -> Iterable:
"""Maps the model's `forward` method over the given dataset, without
running a closure on the outputs."""
yield from self.imap(lambda x: x, dataset, model_kwargs=model_kwargs, use_tqdm=use_tqdm)
yield from self.imap(
lambda x: x, dataset, model_kwargs=model_kwargs, use_tqdm=use_tqdm
)

def map(
self,
@@ -342,9 +356,7 @@ def maybe_unsqueeze(v):
outputs = model(**inputs_cuda, **model_kwargs)

if callable(closure):
outputs = closure(
outputs, **record
)
outputs = closure(outputs, **record)
if outputs is not None:
# Move the outputs back to the CPU
outputs = pytree_map(lambda x: x.cpu().share_memory_(), outputs)
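For reference, a rough usage sketch of the helpers reformatted in inference_server.py, based only on the constructor and method signatures visible in this diff; the model name and dataset contents are placeholders, not repository code.

```python
from datasets import Dataset

from elk.extraction.inference_server import InferenceServer

# Placeholder tokenized dataset; in the repository this comes from tokenize_dataset.
encodings = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5, 6]]})

with InferenceServer(
    model_str="gpt2",  # placeholder model
    num_workers=1,
    cpu_offload=True,
    fsdp=False,
) as server:
    # map_forward runs the model's forward pass over the dataset and returns
    # the raw outputs, without applying a closure to them.
    outputs = server.map_forward(
        encodings, model_kwargs=dict(output_hidden_states=True), use_tqdm=True
    )
```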
