Skip to content

Commit

Permalink
fix merge
Browse files Browse the repository at this point in the history
  • Loading branch information
thejaminator committed May 3, 2023
2 parents 69bbf64 + 8ba18c3 commit ae23b59
Show file tree
Hide file tree
Showing 11 changed files with 270 additions and 111 deletions.
2 changes: 1 addition & 1 deletion elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def apply_to_layer(
experiment_dir = elk_reporter_dir() / self.source

reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
reporter: Reporter = torch.load(reporter_path, map_location=device)
reporter = Reporter.load(reporter_path, map_location=device)
reporter.eval()

row_bufs = defaultdict(list)
Expand Down
53 changes: 39 additions & 14 deletions elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
from dataclasses import InitVar, dataclass, replace
from itertools import islice, zip_longest
from itertools import zip_longest
from typing import Any, Iterable, Literal
from warnings import filterwarnings

Expand Down Expand Up @@ -30,7 +30,7 @@
Color,
assert_type,
colorize,
float32_to_int16,
float_to_int16,
infer_label_column,
infer_num_classes,
instantiate_tokenizer,
Expand Down Expand Up @@ -102,6 +102,11 @@ class Extract(Serializable):
case of encoder-decoder models."""

def __post_init__(self, layer_stride: int):
if len(self.datasets) == 0:
raise ValueError(
"Must specify at least one dataset to extract hiddens from."
)

if len(self.max_examples) > 2:
raise ValueError(
"max_examples should be a list of length 0, 1, or 2,"
Expand Down Expand Up @@ -169,7 +174,6 @@ def extract_hiddens(
ds_names = cfg.datasets
assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time."


model = instantiate_model_with_devices(
cfg=cfg, device_config=devices, is_verbose=is_verbose
)
Expand Down Expand Up @@ -202,13 +206,25 @@ def extract_hiddens(
layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers + 1))

global_max_examples = cfg.max_examples[0 if split_type == "train" else 1]

# break `max_examples` among the processes roughly equally
max_examples = global_max_examples // world_size
max_length = assert_type(int, tokenizer.model_max_length)

# Keep track of the number of examples we've yielded so far. We can't do something
# clean like `islice` the dataset, because we skip examples that are too long, and
# we can't predict how many of those there will be.
num_yielded = 0

# the last process gets the remainder (which is usually small)
if rank == world_size - 1:
max_examples += global_max_examples % world_size

for example in islice(prompt_ds, max_examples):
for example in prompt_ds:
# Check if we've yielded enough examples
if num_yielded >= max_examples:
break

num_variants = len(example["prompts"])
num_choices = len(example["prompts"][0])

Expand Down Expand Up @@ -240,19 +256,14 @@ def extract_hiddens(

# Only feed question, not the answer, to the encoder for enc-dec models
target = choice["answer"] if is_enc_dec else None

# Record the EXACT question we fed to the model
variant_questions.append(text)
encoding = tokenizer(
text,
# Keep [CLS] and [SEP] for BERT-style models
add_special_tokens=True,
return_tensors="pt",
text_target=target, # type: ignore[arg-type]
truncation=True,
).to(first_device)
input_ids = assert_type(Tensor, encoding.input_ids)

if is_enc_dec:
answer = assert_type(Tensor, encoding.labels)
else:
Expand All @@ -263,11 +274,14 @@ def extract_hiddens(
return_tensors="pt",
).to(first_device)
answer = assert_type(Tensor, encoding2.input_ids)

input_ids = torch.cat([input_ids, answer], dim=-1)
if max_len := tokenizer.model_max_length:
cur_len = input_ids.shape[-1]
input_ids = input_ids[..., -min(cur_len, max_len) :]

# If this input is too long, skip it
if input_ids.shape[-1] > max_length:
break
else:
# Record the EXACT question we fed to the model
variant_questions.append(text)

# Make sure we only pass the arguments that the model expects
inputs = dict(input_ids=input_ids.long())
Expand Down Expand Up @@ -307,10 +321,20 @@ def extract_hiddens(
raise ValueError(f"Invalid token_loc: {cfg.token_loc}")

for layer_idx, hidden in zip(layer_indices, hiddens):
hidden_dict[f"hidden_{layer_idx}"][i, j] = float32_to_int16(hidden)
hidden_dict[f"hidden_{layer_idx}"][i, j] = float_to_int16(hidden)

# We skipped a pseudolabel because it was too long; break out of this whole
# example and move on to the next one
if len(variant_questions) != num_choices:
break

# Usual case: we have the expected number of pseudolabels
text_questions.append(variant_questions)

# We skipped a variant because it was too long; move on to the next example
if len(text_questions) != num_variants:
continue

out_record: dict[str, Any] = dict(
label=example["label"],
variant_ids=example["template_names"],
Expand All @@ -320,6 +344,7 @@ def extract_hiddens(
if has_lm_preds:
out_record["model_logits"] = lm_logits

num_yielded += 1
yield out_record


Expand Down
13 changes: 8 additions & 5 deletions elk/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torch.multiprocessing as mp
import yaml
from simple_parsing.helpers import Serializable, field
from simple_parsing.helpers.serialization import save
from torch import Tensor
from tqdm import tqdm

Expand All @@ -37,12 +38,14 @@ class Run(ABC, Serializable):
"""Directory to save results to. If None, a directory will be created
automatically."""

datasets: list[DatasetDictWithName] = field(default_factory=list, init=False)
datasets: list[DatasetDictWithName] = field(
default_factory=list, init=False, to_dict=False
)
"""Datasets containing hidden states and labels for each layer."""

concatenated_layer_offset: int = 0
debug: bool = False
min_gpu_mem: int | None = None
min_gpu_mem: int | None = None # in bytes
num_gpus: int = -1
out_dir: Path | None = None
disable_cache: bool = field(default=False, to_dict=False)
Expand Down Expand Up @@ -78,9 +81,9 @@ def execute(
print(f"Output directory at \033[1m{self.out_dir}\033[0m")
self.out_dir.mkdir(parents=True, exist_ok=True)

path = self.out_dir / "cfg.yaml"
with open(path, "w") as f:
self.dump_yaml(f)
# save_dc_types really ought to be the default... We simply can't load
# properly without this flag enabled.
save(self, self.out_dir / "cfg.yaml", save_dc_types=True)

path = self.out_dir / "fingerprints.yaml"
with open(path, "w") as meta_f:
Expand Down
19 changes: 17 additions & 2 deletions elk/training/ccs_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import math
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal, Optional, cast

import torch
Expand Down Expand Up @@ -59,7 +60,6 @@ class CcsReporterConfig(ReporterConfig):
loss_dict: dict[str, float] = field(default_factory=dict, init=False)
num_layers: int = 1
pre_ln: bool = False
seed: int = 42
supervised_weight: float = 0.0

lr: float = 1e-2
Expand All @@ -68,6 +68,10 @@ class CcsReporterConfig(ReporterConfig):
optimizer: Literal["adam", "lbfgs"] = "lbfgs"
weight_decay: float = 0.01

@classmethod
def reporter_class(cls) -> type[Reporter]:
return CcsReporter

def __post_init__(self):
self.loss_dict = parse_loss(self.loss)

Expand All @@ -94,6 +98,11 @@ def __init__(
):
super().__init__()
self.config = cfg
self.in_features = in_features

# Learnable Platt scaling parameters
self.bias = nn.Parameter(torch.zeros(1, device=device, dtype=dtype))
self.scale = nn.Parameter(torch.ones(1, device=device, dtype=dtype))

hidden_size = cfg.hidden_size or 4 * in_features // 3

Expand Down Expand Up @@ -239,7 +248,7 @@ def forward(self, x: Tensor) -> Tensor:

def raw_forward(self, x: Tensor) -> Tensor:
"""Apply the probe to the provided input, without normalization."""
return self.probe(x).squeeze(-1)
return self.probe(x).mul(self.scale).add(self.bias).squeeze(-1)

def loss(
self,
Expand Down Expand Up @@ -401,3 +410,9 @@ def closure():

optimizer.step(closure)
return float(loss)

def save(self, path: Path | str) -> None:
"""Save the reporter to a file."""
state = {k: v.cpu() for k, v in self.state_dict().items()}
state.update(in_features=self.in_features)
torch.save(state, path)
Loading

0 comments on commit ae23b59

Please sign in to comment.