Merge branch 'main' into mislead-replicate
norabelrose committed Mar 17, 2023
2 parents 34a197f + 69da185 commit 1302c76
Showing 23 changed files with 1,324 additions and 535 deletions.
38 changes: 10 additions & 28 deletions .github/workflows/cpu_ci.yml
@@ -3,14 +3,17 @@ name: "Run CPU Tests"
on: "push"

jobs:
run-tests-python3_9:
runs-on: ubuntu-latest
run-tests:
strategy:
matrix:
python-versions: [ 3.9, "3.10", "3.11" ]
os: [ ubuntu-latest, macos-latest ]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- name: Install Python
uses: actions/setup-python@v4
- uses: actions/setup-python@v4
with:
python-version: "3.9"
python-version: ${{ matrix.python-versions }}

- name: Upgrade Pip
run: python -m pip install --upgrade pip
@@ -21,26 +24,5 @@ jobs:
- name: Type Checking
uses: jakebailey/pyright-action@v1

- name: Run CPU Tests
run: pytest -m cpu

run-tests-python3_10:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install Python
uses: actions/setup-python@v4
with:
python-version: "3.10"

- name: Upgrade Pip
run: python -m pip install --upgrade pip

- name: Install Dependencies
run: pip install -e .[dev]

- name: Type Checking
uses: jakebailey/pyright-action@v1

- name: Run CPU Tests
run: pytest -m cpu
- name: Run normal tests, excluding GPU tests
run: pytest -m "not gpu"
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -20,7 +20,7 @@ repos:
- id: flake8
args: ["--ignore=E203,F401,W503", --max-line-length=88]
- repo: https://github.com/codespell-project/codespell
rev: v2.2.2
rev: v2.2.4
hooks:
- id: codespell
# The promptsource templates spuriously get flagged without this
6 changes: 6 additions & 0 deletions README.md
@@ -38,6 +38,12 @@ To only extract the hidden states for the model `model` and the dataset `dataset`
elk extract microsoft/deberta-v2-xxlarge-mnli imdb -o my_output_dir
```

The following will generate a CCS reporter instead of the Eigen reporter, which is the default.

```bash
elk elicit microsoft/deberta-v2-xxlarge-mnli imdb --net ccs
```

## Development
Use `pip install pre-commit && pre-commit install` in the root folder before your first commit.

90 changes: 90 additions & 0 deletions elk/calibration.py
@@ -0,0 +1,90 @@
from dataclasses import dataclass, field
from torch import Tensor
from typing import NamedTuple
import torch
import warnings


class CalibrationEstimate(NamedTuple):
ece: float
num_bins: int


@dataclass
class CalibrationError:
"""Monotonic Sweep Calibration Error for binary problems.
This method estimates the True Calibration Error (TCE) by searching for the largest
number of bins into which the data can be split that preserves the monotonicity
of the predicted confidence -> empirical accuracy mapping. We use equal mass bins
(quantiles) instead of equal width bins. Roelofs et al. (2020) show that this
estimator has especially low bias in simulations where the TCE is analytically
computable, and is hyperparameter-free (except for the type of norm used).
Paper: "Mitigating Bias in Calibration Error Estimation" by Roelofs et al. (2020)
Link: https://arxiv.org/abs/2012.08668
"""

labels: list[Tensor] = field(default_factory=list)
pred_probs: list[Tensor] = field(default_factory=list)

def update(self, labels: Tensor, probs: Tensor) -> "CalibrationError":
labels, probs = labels.detach().flatten(), probs.detach().flatten()
assert labels.shape == probs.shape
assert torch.is_floating_point(probs)

self.labels.append(labels)
self.pred_probs.append(probs)
return self

def compute(self, p: int = 2) -> CalibrationEstimate:
"""Compute the expected calibration error.
Args:
p: The norm to use for the calibration error. Defaults to 2 (Euclidean).
"""
labels = torch.cat(self.labels)
pred_probs = torch.cat(self.pred_probs)

n = len(pred_probs)
if n < 2:
raise ValueError("Not enough data to compute calibration error.")

# Sort the predictions and labels
pred_probs, indices = pred_probs.sort()
labels = labels[indices].float()

# Search for the largest number of bins which preserves monotonicity.
# Based on Algorithm 1 in Roelofs et al. (2020).
# Using a single bin is guaranteed to be monotonic, so we start there.
b_star, accs_star = 1, labels.mean().unsqueeze(0)
for b in range(2, n + 1):
# Split into (nearly) equal mass bins
freqs = torch.stack([h.mean() for h in labels.tensor_split(b)])

# This binning is not strictly monotonic, let's break
if not torch.all(freqs[1:] > freqs[:-1]):
break

elif not torch.all(freqs * (1 - freqs)):
warnings.warn(
"Calibration error estimate may be unreliable due to insufficient"
" data in some bins."
)
break

# Save the current binning, it's monotonic and may be the best one
else:
accs_star = freqs
b_star = b

# Split into (nearly) equal mass bins. They won't be exactly equal, so we
# still weight the bins by their size.
conf_bins = pred_probs.tensor_split(b_star)
w = torch.tensor([len(c) / n for c in conf_bins])

# See the definition of ECE_sweep in Equation 8 of Roelofs et al. (2020)
mean_confs = torch.stack([c.mean() for c in conf_bins])
ece = torch.sum(w * torch.abs(accs_star - mean_confs) ** p) ** (1 / p)

return CalibrationEstimate(float(ece), b_star)
100 changes: 62 additions & 38 deletions elk/extraction/extraction.py
@@ -24,7 +24,6 @@
from simple_parsing.helpers import field, Serializable
from transformers import (
AutoConfig,
AutoModel,
AutoTokenizer,
BatchEncoding,
PreTrainedModel,
@@ -84,18 +83,27 @@ def get_arch_string(arch_suffix):
return arch_str
return None

suffixes = ['SequenceClassification', 'CausalLM', 'LMHeadModel', 'ConditionalGeneration']
suffixes = [
"SequenceClassification",
"CausalLM",
"LMHeadModel",
"ConditionalGeneration",
]

for suffix in suffixes:
supported_arch = get_arch_string(suffix)

if supported_arch is None:
continue

return getattr(transformers, supported_arch)

raise ValueError(f'{model_str} does not support any architectures in the list: {architectures}')
raise ValueError(
f"{model_str} does not support any architectures in the list: {architectures}"
)


@torch.no_grad()
def extract_hiddens(
cfg: ExtractionConfig,
*,
@@ -127,7 +135,11 @@ def extract_hiddens(

# get_model_class should do the right thing here in nearly all cases. We don't
# actually care what head the model has, since we are just extracting hidden states.
model = get_model_class(cfg.model).from_pretrained(cfg.model, torch_dtype="auto").to(device)
model = (
get_model_class(cfg.model)
.from_pretrained(cfg.model, torch_dtype="auto")
.to(device)
)
# TODO: Maybe also make this configurable?
# We want to make sure the answer is never truncated
tokenizer = AutoTokenizer.from_pretrained(cfg.model, truncation_side="left")
Expand Down Expand Up @@ -169,11 +181,12 @@ def tokenize(prompt: Prompt, idx: int, **kwargs):
# need to store locations of output tokens
# beginning and end of span are equal iff it's a special token
answer_indices = [
token_idx for token_idx, span in enumerate(tokenized['offset_mapping'][0])
token_idx
for token_idx, span in enumerate(tokenized["offset_mapping"][0])
if (answer_start <= span[0] and span[0] != span[1])
]

del tokenized['offset_mapping']
del tokenized["offset_mapping"]

return (tokenized, answer_indices)

@@ -187,7 +200,11 @@ def tokenize(prompt: Prompt, idx: int, **kwargs):
**kwargs,
).to(device)

answer_indices = [i for i, tkn in enumerate(tokenized['labels'][0]) if tkn not in tokenizer.all_special_ids]
answer_indices = [
i
for i, tkn in enumerate(tokenized["labels"][0])
if tkn not in tokenizer.all_special_ids
]

return (tokenized, answer_indices)

@@ -224,7 +241,9 @@ def collate(prompts: list[Prompt]) -> list[list[BatchEncoding]]:
}
variant_ids = [prompt.template_name for prompt in prompts]

logprobs_all = torch.empty((prompt_ds.num_variants, num_choices), device=device, dtype=torch.float16)
logprobs_all = torch.empty(
(prompt_ds.num_variants, num_choices), device=device, dtype=torch.float16
)

# Iterate over variants
for i, variant_inputs in enumerate(inputs):
@@ -254,23 +273,27 @@ def collate(prompts: list[Prompt]) -> list[list[BatchEncoding]]:

for layer_idx, hidden in zip(layer_indices, hiddens):
hidden_dict[f"hidden_{layer_idx}"][i, j] = float32_to_int16(hidden)

logprobs = outputs.logits.log_softmax(dim=-1)

if should_concat:
# offset predictions targets (target i=1 -> prediction i=0)
input_tokens = inpt['input_ids'][:, 1:]
input_tokens = inpt["input_ids"][:, 1:]
answer_indices = [idx - 1 for idx in answer_indices]

logprobs = torch.gather(logprobs[:, :-1], 2, input_tokens.unsqueeze(-1)).squeeze(-1)
logprobs = torch.gather(
logprobs[:, :-1], 2, input_tokens.unsqueeze(-1)
).squeeze(-1)

logprobs_all[i, j] = torch.sum(logprobs.squeeze(0)[answer_indices])

else:
# labels don't need offset
answer_tokens = inpt['labels']
answer_tokens = inpt["labels"]

logprobs = torch.gather(logprobs, 2, answer_tokens.unsqueeze(-1)).squeeze(-1)
logprobs = torch.gather(
logprobs, 2, answer_tokens.unsqueeze(-1)
).squeeze(-1)

logprobs_all[i, j] = torch.sum(logprobs.squeeze(0)[answer_indices])

@@ -291,39 +314,40 @@ def extract(cfg: ExtractionConfig, max_gpus: int = -1) -> DatasetDict:
"""Extract hidden states from a model and return a `DatasetDict` containing them."""

def get_splits() -> SplitDict:
base_splits = assert_type(SplitDict, info.splits)
splits = set(base_splits) & {Split.TRAIN, Split.VALIDATION, Split.TEST}
if Split.VALIDATION in splits and Split.TEST in splits:
splits.remove(Split.TEST)
available_splits = assert_type(SplitDict, info.splits)
priorities = {
Split.TRAIN: 0,
Split.VALIDATION: 1,
Split.TEST: 2,
}
splits = sorted(available_splits, key=lambda k: priorities.get(k, 100))
assert len(splits) >= 2, "Must have train and val/test splits"

assert len(splits) == 2, "Must have train and val/test splits"
val_split = Split.VALIDATION if Split.VALIDATION in splits else Split.TEST
# Take the first two splits
splits = splits[:2]
print(f"Using '{splits[0]}' for training and '{splits[1]}' for validation")

# grab the max number of examples from the config for each split
limit = (
{
Split.TRAIN: cfg.prompts.max_examples[0],
val_split: cfg.prompts.max_examples[0]
if len(cfg.prompts.max_examples) == 1
else cfg.prompts.max_examples[1],
}
if cfg.prompts.max_examples
else {
Split.TRAIN: int(1e100),
val_split: int(1e100),
}
)
# Empty list means no limit
limit_list = cfg.prompts.max_examples
if not limit_list:
limit_list = [int(1e100)]

# Broadcast the limit to all splits
if len(limit_list) == 1:
limit_list *= len(splits)

limit = {k: v for k, v in zip(splits, limit_list)}
return SplitDict(
{
k: SplitInfo(
name=k,
num_examples=min(limit[k], v.num_examples),
dataset_name=v.dataset_name,
)
for k, v in base_splits.items()
for k, v in available_splits.items()
if k in splits
},
dataset_name=base_splits.dataset_name,
dataset_name=available_splits.dataset_name,
)

model_cfg = AutoConfig.from_pretrained(cfg.model)
@@ -352,7 +376,7 @@ def get_splits() -> SplitDict:
"logprobs": Array2D(
dtype="float32",
shape=(num_variants, 2),
)
),
}
devices = select_usable_devices(max_gpus)
builders = {