MAJOR UPDATE: Remove CCS, VINC (#292)
* save hiddens to disk

* remove contrast pairs

* fix tests

* add LEACE to supervised

* add assertion for multi-dataset erasure

* add blank template for statements

* mvp working for llama

* inference server working with ids

* refactor extraction to use InferenceServer

* mvp with inference server

* fix caching

* don't load model when using cache

* add default template

* maybe unsqueeze

* gutted elk; updated tests

* save logprobs

* add balance and max_inlp_iter args

* extract lm predictions

* lm preds

* add encodings test, cleanup

* ignore type issue

* revisions from Nora's feedback; move output_hidden_states to model_kwargs, fix answer token being appended, fix viz, fix tqdm propagation

* cleanup

* re-fix tests

* fix save_logprobs

* fix layer sorting in logprobs.pt

* mark gpu tests

* test logprobs

* remove buggy viz test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
AlexTMallen and pre-commit-ci[bot] authored Oct 23, 2023
1 parent 670eaec commit 70a3290
Showing 38 changed files with 1,204 additions and 1,886 deletions.
README.md: 3 changes (2 additions, 1 deletion)
@@ -21,7 +21,8 @@ Our code is based on [PyTorch](http://pytorch.org)
and [Huggingface Transformers](https://huggingface.co/docs/transformers/index). We test the code on Python 3.10 and
3.11.

First install the package with `pip install -e .` in the root directory, or `pip install eleuther-elk` to install from PyPi. Use `pip install -e .[dev]` if you'd like to contribute to the project (see **Development** section below). This should install all the necessary dependencies.
First install the package with `pip install -e .` in the root directory, or `pip install -e .[dev]` if you'd like to
contribute to the project (see **Development** section below). This should install all the necessary dependencies.

To fit reporters for the HuggingFace model `model` and dataset `dataset`, just run:

comparison-sweeps: 1 change (0 additions, 1 deletion)
Submodule comparison-sweeps deleted from f4ed88
elk/__init__.py: 12 changes (5 additions, 7 deletions)
@@ -1,11 +1,9 @@
from .extraction import Extract, extract_hiddens
from .training import EigenFitter, EigenFitterConfig
from .truncated_eigh import truncated_eigh
from .evaluation import Eval
from .extraction import Extract
from .training.train import Elicit

__all__ = [
"EigenFitter",
"EigenFitterConfig",
"extract_hiddens",
"Extract",
"truncated_eigh",
"Elicit",
"Eval",
]
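
With CCS and VINC removed, the package's public surface shrinks to `Extract`, `Elicit`, and `Eval`. Below is a minimal sketch of driving the trimmed API programmatically; the field names (`model`, `datasets`, `data`) and the `execute()` entry point are assumptions for illustration, not something this diff confirms.

```python
# Hypothetical usage of the trimmed public API after this commit.
# Field and method names below are assumptions; check elk's docs for the
# real signatures before relying on them.
from elk import Elicit, Eval, Extract

cfg = Extract(model="gpt2", datasets=("imdb",))  # what to extract hiddens from
Elicit(data=cfg).execute()                       # fit supervised probes per layer
Eval(data=cfg, source="gpt2/imdb").execute()     # evaluate the saved probes elsewhere

# The README's equivalent CLI path is roughly `elk elicit <model> <dataset>`
# followed by `elk eval <source> <model> <dataset>` (also an assumption here).
```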
elk/debug_logging.py: 22 changes (10 additions, 12 deletions)
@@ -31,26 +31,24 @@ def save_debug_log(datasets: list[DatasetDictWithName], out_dir: Path) -> None:
else:
train_split, val_split = select_train_val_splits(ds)

text_questions = ds[val_split][0]["text_questions"]
if len(ds[val_split]) == 0:
logging.warning(f"Val split '{val_split}' is empty!")
continue

texts = ds[val_split][0]["texts"]
template_ids = ds[val_split][0]["variant_ids"]
label = ds[val_split][0]["label"]
ds[val_split][0]["label"]

# log the train size and val size
if train_split is not None:
logging.info(f"Train size: {len(ds[train_split])}")
logging.info(f"Val size: {len(ds[val_split])}")

templates_text = f"{len(text_questions)} templates used:\n"
templates_text = f"{len(texts)} templates used:\n"
trailing_whitespace = False
for (text0, text1), id in zip(text_questions, template_ids):
templates_text += (
f'***---TEMPLATE "{id}"---***\n'
f"{'false' if label else 'true'}:\n"
f'"""{text0}"""\n'
f"{'true' if label else 'false'}:\n"
f'"""{text1}"""\n\n\n'
)
if text0[-1].isspace() or text1[-1].isspace():
for text, id in zip(texts, template_ids):
templates_text += f'***---TEMPLATE "{id}"---***\n' f'"""{text}"""\n'
if text[-1].isspace():
trailing_whitespace = True
if trailing_whitespace:
logging.warning(
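
Since contrast pairs are gone, each debug-log row now carries one statement string per template instead of a (false, true) pair, so the template dump and the trailing-whitespace check run over single texts. Here is a self-contained sketch of that loop with invented sample data; the row keys mirror the diff above, and the warning text is a placeholder because the real message is truncated out of view.

```python
import logging

# One validation row shaped like the keys read above ("texts", "variant_ids");
# the values are invented examples for illustration only.
row = {
    "texts": [
        "The movie was great. This statement is true.",
        "The movie was great. Sentiment: positive. ",  # note the trailing space
    ],
    "variant_ids": ["plain", "sentiment"],
}

templates_text = f"{len(row['texts'])} templates used:\n"
trailing_whitespace = False
for text, id in zip(row["texts"], row["variant_ids"]):
    templates_text += f'***---TEMPLATE "{id}"---***\n' f'"""{text}"""\n'
    if text[-1].isspace():
        trailing_whitespace = True

if trailing_whitespace:
    # Placeholder warning; the actual message is cut off in the diff above.
    logging.warning("One or more templates end in trailing whitespace.")
```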
elk/evaluation/evaluate.py: 84 changes (48 additions, 36 deletions)
@@ -7,7 +7,7 @@
from simple_parsing.helpers import field

from ..files import elk_reporter_dir
from ..metrics import evaluate_preds
from ..metrics import evaluate_preds, get_logprobs
from ..run import Run
from ..utils import Color

@@ -17,7 +17,6 @@ class Eval(Run):
"""Full specification of a reporter evaluation run."""

source: Path = field(positional=True)
skip_supervised: bool = False

def __post_init__(self):
# Set our output directory before super().execute() does
@@ -31,55 +30,68 @@ def execute(self, highlight_color: Color = "cyan"):
@torch.inference_mode()
def apply_to_layer(
self, layer: int, devices: list[str], world_size: int
) -> dict[str, pd.DataFrame]:
) -> tuple[dict[str, pd.DataFrame], dict]:
"""Evaluate a single reporter on a single layer."""
device = self.get_device(devices, world_size)
val_output = self.prepare_data(device, layer, "val")

experiment_dir = elk_reporter_dir() / self.source

reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
reporter = torch.load(reporter_path, map_location=device)
lr_dir = experiment_dir / "lr_models"
with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
lr_models = torch.load(f, map_location=device)
if not isinstance(lr_models, list): # backward compatibility
lr_models = [lr_models]

out_logprobs = defaultdict(dict)
row_bufs = defaultdict(list)
for ds_name, (val_h, val_gt, val_lm_preds) in val_output.items():
for ds_name, val_data in val_output.items():
meta = {"dataset": ds_name, "layer": layer}

val_credences = reporter(val_h)
for mode in ("none", "partial", "full"):
row_bufs["eval"].append(
{
**meta,
"ensembling": mode,
**evaluate_preds(val_gt, val_credences, mode).to_dict(),
}
if self.save_logprobs:
out_logprobs[ds_name] = dict(
row_ids=val_data.row_ids.cpu(),
variant_ids=val_data.variant_ids,
texts=val_data.texts,
labels=val_data.labels.cpu(),
lm=dict(),
lr=dict(),
)

if val_lm_preds is not None:
for mode in ("none", "full"):
if val_data.lm_log_odds is not None:
if self.save_logprobs:
out_logprobs[ds_name]["lm"][mode] = get_logprobs(
val_data.lm_log_odds, mode
).cpu()
row_bufs["lm_eval"].append(
{
**meta,
"ensembling": mode,
**evaluate_preds(val_gt, val_lm_preds, mode).to_dict(),
**meta,
**evaluate_preds(
val_data.labels, val_data.lm_log_odds, mode
).to_dict(),
}
)

lr_dir = experiment_dir / "lr_models"
if not self.skip_supervised and lr_dir.exists():
with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
lr_models = torch.load(f, map_location=device)
if not isinstance(lr_models, list): # backward compatibility
lr_models = [lr_models]

for i, model in enumerate(lr_models):
model.eval()
row_bufs["lr_eval"].append(
{
"ensembling": mode,
"inlp_iter": i,
**meta,
**evaluate_preds(val_gt, model(val_h), mode).to_dict(),
}
)
if self.save_logprobs:
out_logprobs[ds_name]["lr"][mode] = dict()

for i, model in enumerate(lr_models):
model.eval()
val_log_odds = model(val_data.hiddens)
if self.save_logprobs:
out_logprobs[ds_name]["lr"][mode][i] = get_logprobs(
val_log_odds, mode
).cpu()
row_bufs["lr_eval"].append(
{
"ensembling": mode,
"inlp_iter": i,
**meta,
**evaluate_preds(
val_data.labels, val_log_odds, mode
).to_dict(),
}
)

return {k: pd.DataFrame(v) for k, v in row_bufs.items()}
return {k: pd.DataFrame(v) for k, v in row_bufs.items()}, out_logprobs
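
`apply_to_layer` now returns a second value alongside the metric DataFrames: a per-dataset dict of row metadata and log-probabilities, with LM log-probs keyed by ensembling mode and logistic-regression log-probs keyed by mode and INLP iteration. The sketch below shows how one might inspect the saved output, assuming the per-layer dicts end up in a single `logprobs.pt` keyed by layer; the file name appears in the commit messages, but the exact on-disk nesting and the keys used below are assumptions.

```python
import torch

# Assumed layout: {layer: {dataset_name: {...}}}, mirroring the keys built
# in apply_to_layer above; the layer/dataset keys here are hypothetical.
logprobs = torch.load("logprobs.pt")
layer_out = logprobs[12]["imdb"]

print(layer_out["row_ids"].shape)        # one id per example row
print(layer_out["labels"].shape)         # ground-truth labels
print(layer_out["lm"]["full"].shape)     # LM log-probs, fully ensembled over templates
print(layer_out["lr"]["full"][0].shape)  # logistic-regression log-probs, INLP iteration 0
```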
elk/extraction/__init__.py: 9 changes (6 additions, 3 deletions)
@@ -1,15 +1,18 @@
from .balanced_sampler import BalancedSampler, FewShotSampler
from .extraction import Extract, extract, extract_hiddens
from .extraction import Extract, extract, tokenize_dataset
from .generator import _GeneratorBuilder, _GeneratorConfig
from .prompt_loading import load_prompts
from .inference_server import InferenceServer
from .prompt_loading import get_prompter, load_prompts

__all__ = [
"BalancedSampler",
"FewShotSampler",
"Extract",
"extract_hiddens",
"InferenceServer",
"extract",
"_GeneratorConfig",
"_GeneratorBuilder",
"load_prompts",
"get_prompter",
"tokenize_dataset",
]