From f57845acdbd682044d665b8379d56f10612aa3ed Mon Sep 17 00:00:00 2001 From: Alex Mallen <35092692+AlexTMallen@users.noreply.github.com> Date: Fri, 10 Mar 2023 15:16:24 -0800 Subject: [PATCH 01/10] Support boolq inference (#121) * allow boolean label col; add boolq templates * remove templates without passage * fix check for label column type --- elk/extraction/prompt_dataset.py | 27 +- elk/parsing.py | 3 +- .../templates/boolq/templates.yaml | 256 ++++++++++++++++++ elk/utils/data_utils.py | 9 +- 4 files changed, 284 insertions(+), 11 deletions(-) create mode 100644 elk/promptsource/templates/boolq/templates.yaml diff --git a/elk/extraction/prompt_dataset.py b/elk/extraction/prompt_dataset.py index 2b5d45e7..336d3a85 100644 --- a/elk/extraction/prompt_dataset.py +++ b/elk/extraction/prompt_dataset.py @@ -2,7 +2,7 @@ from ..promptsource import DatasetTemplates from ..utils import assert_type, compute_class_balance, infer_label_column, undersample from dataclasses import dataclass -from datasets import DatasetDict, load_dataset +from datasets import DatasetDict, load_dataset, ClassLabel, Value from numpy.typing import NDArray from random import Random from simple_parsing.helpers import field, Serializable @@ -126,10 +126,14 @@ def __init__( # Enforce class balance if needed if cfg.balance: - self.active_split = undersample(self.active_split, self.rng, label_col) + self.active_split = undersample( + self.active_split, self.rng, self.num_classes, label_col + ) self.class_fracs = np.ones(self.num_classes) / self.num_classes else: - class_sizes = compute_class_balance(self.active_split, label_col) + class_sizes = compute_class_balance( + self.active_split, self.num_classes, label_col + ) self.class_fracs: NDArray[np.floating] = class_sizes / class_sizes.sum() # We use stratified sampling to create few-shot prompts that are as balanced as @@ -238,6 +242,17 @@ def __len__(self): @property def num_classes(self) -> int: """Number of classes in the underlying dataset.""" - - # We piggyback on the ClassLabel feature type to get the number of classes - return self.active_split.features[self.label_column].num_classes + if isinstance(self.active_split.features[self.label_column], ClassLabel): + # We piggyback on the ClassLabel feature type to get the number of classes + return self.active_split.features[self.label_column].num_classes + elif ( + isinstance(self.active_split.features[self.label_column], Value) + and self.active_split.features[self.label_column].dtype == "bool" + ): + return 2 + else: + raise ValueError( + f"Can't infer number of classes from label column " + f"{self.label_column} of type " + f"{self.active_split.features[self.label_column]}" + ) diff --git a/elk/parsing.py b/elk/parsing.py index 465f2598..c40a8473 100644 --- a/elk/parsing.py +++ b/elk/parsing.py @@ -4,7 +4,8 @@ def parse_loss(terms: list[str]) -> dict[str, float]: """Parse the loss command line argument list into a dictionary.""" - assert len(terms) > 0, "No loss terms specified." 
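+    # If no loss terms are given, fall back to the combined CCS +
+    # prompt-variance loss rather than erroring out.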
+ if len(terms) == 0: + return {"ccs_prompt_var": 1.0} loss_dict = dict() for term in terms: if term in loss_dict: diff --git a/elk/promptsource/templates/boolq/templates.yaml b/elk/promptsource/templates/boolq/templates.yaml new file mode 100644 index 00000000..c4d59f2f --- /dev/null +++ b/elk/promptsource/templates/boolq/templates.yaml @@ -0,0 +1,256 @@ +dataset: boolq +templates: + 0351d12a-e630-4d19-8b41-e199266e38f7: !Template + answer_choices: No ||| Yes + id: 0351d12a-e630-4d19-8b41-e199266e38f7 + jinja: '{{passage}} Given the above passage, {{ question }} ||| + {% if answer %} + {{ answer_choices[1] }} + {% else %} + {{ answer_choices[0] }} + {% endif %}' + metadata: !TemplateMetadata + choices_in_prompt: false + languages: + - en + metrics: + - Accuracy + original_task: true + name: mallen_0 + reference: '' + 0f372fb1-795a-47b6-8ddf-c4fd1579e76a: !Template + answer_choices: incorrect ||| correct + id: 0f372fb1-795a-47b6-8ddf-c4fd1579e76a + jinja: 'According to expert sources, "{{passage}}", so next time someone someone asks "{{ question }}", you should say that is ||| {% if answer %} {{ answer_choices[1] }} + {% else %} {{ answer_choices[0] }} + {% endif %}' + metadata: !TemplateMetadata + choices_in_prompt: false + languages: + - en + metrics: + - Accuracy + original_task: true + name: mallen_1 + reference: '' + 066474a5-1498-46b7-bfee-ac0c5160707f: !Template + answer_choices: Negatory. Over. ||| Affirmative. Over. + id: 066474a5-1498-46b7-bfee-ac0c5160707f + jinja: '"{{passage}}" Based on that, {{ question }}? Requesting an answer. Over. ||| + {% if answer %} + {{ answer_choices[1] }} + {% else %} + {{ answer_choices[0] }} + {% endif %}' + metadata: !TemplateMetadata + choices_in_prompt: false + languages: + - en + metrics: + - Accuracy + original_task: true + name: mallen_2 + reference: '' + 066474a5-1498-46b7-bfee-ac0c51607070: !Template + answer_choices: No ||| Yes + id: 066474a5-1498-46b7-bfee-ac0c51607070 + jinja: '{{passage}}\n\nThe answer to "{{ question }}" is ||| {% if answer %} {{ answer_choices[1] }} + {% else %} {{ answer_choices[0] }} + {% endif %}' + metadata: !TemplateMetadata + choices_in_prompt: false + languages: + - en + metrics: + - Accuracy + original_task: true + name: mallen_3 + reference: '' + 8351d12a-e630-4d19-8b41-e199266e38f7: !Template + answer_choices: False ||| True + id: 8351d12a-e630-4d19-8b41-e199266e38f7 + jinja: 'Passage: {{passage}} +After reading this passage, I have a question: {{question}}? True or +False? ||| + {% if answer %} + {{ answer_choices[1] }} + {% else %} + {{ answer_choices[0] }} + {% endif %}' + metadata: !TemplateMetadata + choices_in_prompt: false + languages: + - en + metrics: + - Accuracy + original_task: true + name: sanh_0 + reference: '' + 8f372fb1-795a-47b6-8ddf-c4fd1579e76a: !Template + answer_choices: No ||| Yes + id: 8f372fb1-795a-47b6-8ddf-c4fd1579e76a + jinja: '{{ passage }} +Question: {{ question }} +Answer: ||| {% if answer %} {{ answer_choices[1] }} + {% else %} {{ answer_choices[0] }} + {% endif %}' + metadata: !TemplateMetadata + choices_in_prompt: false + languages: + - en + metrics: + - Accuracy + original_task: true + name: sanh_1 + reference: '' + 866474a5-1498-46b7-bfee-ac0c5160707f: !Template + answer_choices: No ||| Yes + id: 866474a5-1498-46b7-bfee-ac0c5160707f + jinja: '{{ passage }} +Having read that, I wonder {{ question }}? 
||| + {% if answer %} + {{ answer_choices[1] }} + {% else %} + {{ answer_choices[0] }} + {% endif %}' + metadata: !TemplateMetadata + choices_in_prompt: false + languages: + - en + metrics: + - Accuracy + original_task: true + name: sanh_2 + reference: '' + 866474a5-1498-46b7-bfee-ac0c51607070: !Template + answer_choices: No ||| Yes + id: 866474a5-1498-46b7-bfee-ac0c51607070 + jinja: 'Text: {{passage}} +Answer the following yes/no question: {{question}}? Yes or no? ||| + {% if answer %} + {{ answer_choices[1] }} + {% else %} + {{ answer_choices[0] }} + {% endif %}' + metadata: !TemplateMetadata + choices_in_prompt: false + languages: + - en + metrics: + - Accuracy + original_task: true + name: sanh_3 + reference: '' + 9351d12a-e630-4d19-8b41-e199266e38f7: !Template + answer_choices: No ||| Yes + id: 9351d12a-e630-4d19-8b41-e199266e38f7 + jinja: '{{ passage }} +Having read that, could you tell me {{ question }}? ||| + {% if answer %} + {{ answer_choices[1] }} + {% else %} + {{ answer_choices[0] }} + {% endif %}' + metadata: !TemplateMetadata + choices_in_prompt: false + languages: + - en + metrics: + - Accuracy + original_task: true + name: sanh_4 + reference: '' + 9f372fb1-795a-47b6-8ddf-c4fd1579e76a: !Template + answer_choices: No ||| Yes + id: 9f372fb1-795a-47b6-8ddf-c4fd1579e76a + jinja: 'EXAM +1. Answer by yes or no. +Document: {{passage}} +Question: {{question}}?||| + {% if answer %} + {{ answer_choices[1] }} + {% else %} + {{ answer_choices[0] }} + {% endif %}' + metadata: !TemplateMetadata + choices_in_prompt: false + languages: + - en + metrics: + - Accuracy + original_task: true + name: sanh_5 + reference: '' + 966474a5-1498-46b7-bfee-ac0c5160707f: !Template + answer_choices: No ||| Yes + id: 966474a5-1498-46b7-bfee-ac0c5160707f + jinja: 'Based on the following passage, {{ question }}? {{ passage }} ||| + {% if answer %} + {{ answer_choices[1] }} + {% else %} + {{ answer_choices[0] }} + {% endif %}' + metadata: !TemplateMetadata + choices_in_prompt: false + languages: + - en + metrics: + - Accuracy + original_task: true + name: sanh_6 + reference: '' + 966474a5-1498-46b7-bfee-ac0c51607070: !Template + answer_choices: False ||| True + id: 966474a5-1498-46b7-bfee-ac0c51607070 + jinja: 'Exercise: read the text and answer the question by True or False. +Text: {{passage}} +Question: {{question}}? ||| + {% if answer %} + {{ answer_choices[1] }} + {% else %} + {{ answer_choices[0] }} + {% endif %}' + metadata: !TemplateMetadata + choices_in_prompt: false + languages: + - en + metrics: + - Accuracy + original_task: true + name: sanh_7 + reference: '' + a66474a5-1498-46b7-bfee-ac0c5160707f: !Template + answer_choices: No ||| Yes + id: a66474a5-1498-46b7-bfee-ac0c5160707f + jinja: '{{ passage }} +Based on the previous passage, {{ question }}? ||| + {% if answer %} + {{ answer_choices[1] }} + {% else %} + {{ answer_choices[0] }} + {% endif %}' + metadata: !TemplateMetadata + choices_in_prompt: false + languages: + - en + metrics: + - Accuracy + original_task: true + name: sanh_8 + reference: '' + a66474a5-1498-46b7-bfee-ac0c51607070: !Template + answer_choices: False ||| True + id: a66474a5-1498-46b7-bfee-ac0c51607070 + jinja: '{{passage}} +Q: {{question}}? True or False? 
||| {% if answer %} {{ answer_choices[1] }} + {% else %} {{ answer_choices[0] }} + {% endif %}' + metadata: !TemplateMetadata + choices_in_prompt: false + languages: + - en + metrics: + - Accuracy + original_task: true + name: sanh_9 + reference: '' diff --git a/elk/utils/data_utils.py b/elk/utils/data_utils.py index 6c21549e..41788c17 100644 --- a/elk/utils/data_utils.py +++ b/elk/utils/data_utils.py @@ -7,7 +7,9 @@ def compute_class_balance( - dataset: Dataset, label_column: Optional[str] = None + dataset: Dataset, + num_classes: int, + label_column: Optional[str] = None, ) -> np.ndarray: """Compute the class balance of a `Dataset`.""" @@ -18,7 +20,6 @@ def compute_class_balance( elif label_column not in features: raise ValueError(f"{name} has no column '{label_column}'") - num_classes = getattr(features[label_column], "num_classes", 0) class_sizes = np.bincount(dataset[label_column], minlength=num_classes) if not np.all(class_sizes > 0): @@ -71,11 +72,11 @@ def infer_label_column(features: Features) -> str: def undersample( - dataset: Dataset, rng: Random, label_column: Optional[str] = None + dataset: Dataset, rng: Random, num_classes: int, label_column: Optional[str] = None ) -> Dataset: """Undersample a `Dataset` to the smallest class size.""" label_column = label_column or infer_label_column(dataset.features) - class_sizes = compute_class_balance(dataset, label_column) + class_sizes = compute_class_balance(dataset, num_classes, label_column) smallest_size = class_sizes.min() # First group the active split by class From 2c6a10ae47e45c743ffc43ffb158e7bed72657d2 Mon Sep 17 00:00:00 2001 From: Alex Mallen <35092692+AlexTMallen@users.noreply.github.com> Date: Wed, 15 Mar 2023 00:28:49 -0700 Subject: [PATCH 02/10] make fake example copy (#127) --- elk/extraction/prompt_dataset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/elk/extraction/prompt_dataset.py b/elk/extraction/prompt_dataset.py index 336d3a85..c6039218 100644 --- a/elk/extraction/prompt_dataset.py +++ b/elk/extraction/prompt_dataset.py @@ -184,19 +184,20 @@ def __getitem__(self, index: int) -> list[Prompt]: ) example = self.active_split[index] + true_label = example[self.label_column] prompts = [] for template_name in template_names: template = self.prompter.templates[template_name] - true_label = example[self.label_column] answers = [] questions = set() for fake_label in range(self.num_classes): - example[self.label_column] = fake_label + fake_example = example.copy() + fake_example[self.label_column] = fake_label - q, a = template.apply(example) + q, a = template.apply(fake_example) answers.append(a) questions.add(q) From d50876e0382e8229fc3912b15a605395b05499f1 Mon Sep 17 00:00:00 2001 From: Alex Mallen Date: Wed, 15 Mar 2023 18:29:21 +0000 Subject: [PATCH 03/10] add @torch.no_grad to extract_hiddens --- elk/extraction/extraction.py | 1 + 1 file changed, 1 insertion(+) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index fbd69d53..5ca7b794 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -71,6 +71,7 @@ def __post_init__(self, layer_stride: int): self.layers = tuple(range(0, config.num_hidden_layers, layer_stride)) +@torch.no_grad() def extract_hiddens( cfg: ExtractionConfig, *, From 62bc2b4857b0ea44830090df9b2508f256796b0b Mon Sep 17 00:00:00 2001 From: James Chua <30519287+thejaminator@users.noreply.github.com> Date: Thu, 16 Mar 2023 04:44:24 +0800 Subject: [PATCH 04/10] remove marker for cpu, add marker for gpu (#128) Co-authored-by: 
James Chua --- .github/workflows/cpu_ci.yml | 8 ++++---- tests/test_classifier.py | 1 - tests/test_gpu_example.py | 10 ++++++++++ tests/test_math.py | 1 - 4 files changed, 14 insertions(+), 6 deletions(-) create mode 100644 tests/test_gpu_example.py diff --git a/.github/workflows/cpu_ci.yml b/.github/workflows/cpu_ci.yml index 8cbb3fc0..7190d1f8 100644 --- a/.github/workflows/cpu_ci.yml +++ b/.github/workflows/cpu_ci.yml @@ -21,8 +21,8 @@ jobs: - name: Type Checking uses: jakebailey/pyright-action@v1 - - name: Run CPU Tests - run: pytest -m cpu + - name: Run normal tests, excluding GPU tests + run: pytest -m "not gpu" run-tests-python3_10: runs-on: ubuntu-latest @@ -42,5 +42,5 @@ jobs: - name: Type Checking uses: jakebailey/pyright-action@v1 - - name: Run CPU Tests - run: pytest -m cpu + - name: Run normal tests, excluding GPU tests + run: pytest -m "not gpu" diff --git a/tests/test_classifier.py b/tests/test_classifier.py index 0b7a3b59..c55a858e 100644 --- a/tests/test_classifier.py +++ b/tests/test_classifier.py @@ -7,7 +7,6 @@ from elk.training.classifier import Classifier -@pytest.mark.cpu def test_classifier_roughly_same_sklearn(): input_dims: int = 10 # make a classification problem of 1000 samples with input_dims features diff --git a/tests/test_gpu_example.py b/tests/test_gpu_example.py new file mode 100644 index 00000000..2bff5481 --- /dev/null +++ b/tests/test_gpu_example.py @@ -0,0 +1,10 @@ +import pytest + + +@pytest.mark.gpu +def test_gpu_example(): + """Will only run if the `gpu` mark is specified + This is just an example test to show how to use the `gpu` mark + We'll need to implement a GPU runner in the CI for actual GPU tests + GPU tests can be run with `pytest -m gpu`""" + assert True diff --git a/tests/test_math.py b/tests/test_math.py index 5984371b..8cf3016b 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -12,7 +12,6 @@ # ...and the total sum of the floats st.integers(min_value=1, max_value=int(np.finfo(np.float32).max)), ) -@pytest.mark.cpu def test_stochastic_rounding(num_parts: int, total: int): # Randomly sample the breakdown of the total into floats rng = np.random.default_rng(42) From 026af7a5da3f7efc1785c74211e4778c86137666 Mon Sep 17 00:00:00 2001 From: Nora Belrose <39116809+norabelrose@users.noreply.github.com> Date: Wed, 15 Mar 2023 15:23:17 -0700 Subject: [PATCH 05/10] EigenReporter and VINC algorithm (#124) * Added error message for prompt-based loss and num_variants=1 * Added num_variants and ccs_prompt_var error message * changed prompt_var "Only one variant provided. Prompt variance loss will equal CCS loss." 
string to be accurate * changed default loss to ccs * Draft commit * Break Reporter into CcsReporter and EigenReporter * Fix transpose bug * Auto choose solver for device * Initial support for streaming VINC * Tests fr streaming VINC * Fix CcsReporter type check bug * Add fit_streaming * Platt scaling * Platt scaling by default * cleanup eigen_reporter * rename contrastive_cov * fix duplicate "intracluster_cov_M2" * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add --net to readme * Update README.md * Update README.md * rename EigenReporter attributes in test_eigen_reporter.py * Fix warning in Classifier test * Flip sign on the 'loss' returned by EigenReporter.fit * Merge platt_scale into EigenReporter.fit --------- Co-authored-by: Benjamin Co-authored-by: Alex Mallen Co-authored-by: Walter Laurito Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- README.md | 6 + elk/math_util.py | 37 ++++ elk/training/__init__.py | 2 + elk/training/ccs_reporter.py | 319 +++++++++++++++++++++++++++++++ elk/training/eigen_reporter.py | 233 +++++++++++++++++++++++ elk/training/losses.py | 6 +- elk/training/reporter.py | 334 +++++---------------------------- elk/training/train.py | 39 ++-- tests/test_classifier.py | 19 +- tests/test_eigen_reporter.py | 39 ++++ tests/test_math.py | 10 +- 11 files changed, 720 insertions(+), 324 deletions(-) create mode 100644 elk/training/ccs_reporter.py create mode 100644 elk/training/eigen_reporter.py create mode 100644 tests/test_eigen_reporter.py diff --git a/README.md b/README.md index 614fcdfc..3f2e9973 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,12 @@ To only extract the hidden states for the model `model` and the dataset `dataset elk extract microsoft/deberta-v2-xxlarge-mnli imdb -o my_output_dir ``` +The following will generate a CCS reporter instead of the Eigen reporter, which is the default. + +```bash +elk elicit microsoft/deberta-v2-xxlarge-mnli imdb --net ccs +``` + ## Development Use `pip install pre-commit && pre-commit install` in the root folder before your first commit. diff --git a/elk/math_util.py b/elk/math_util.py index 85402d3c..6ab7ef3f 100644 --- a/elk/math_util.py +++ b/elk/math_util.py @@ -1,5 +1,42 @@ +from torch import Tensor +from typing import Optional import math import random +import torch + + +@torch.jit.script +def batch_cov(x: Tensor) -> Tensor: + """Compute a batch of covariance matrices. + + Args: + x: A tensor of shape [..., n, d]. + + Returns: + A tensor of shape [..., d, d]. + """ + x_ = x - x.mean(dim=-2, keepdim=True) + return x_.mT @ x_ / x_.shape[-2] + + +@torch.jit.script +def cov_mean_fused(x: Tensor) -> Tensor: + """Compute the mean of the covariance matrices of a batch of data matrices. + + The computation is done in a memory-efficient way, without materializing all + the covariance matrices in VRAM. + + Args: + x: A tensor of shape [batch, n, d]. + + Returns: + A tensor of shape [d, d]. 
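+
+    Sanity check (illustrative only): the fused result matches materializing
+    each covariance matrix with `batch_cov` and averaging:
+
+        >>> x = torch.randn(8, 100, 16, dtype=torch.float64)
+        >>> torch.allclose(cov_mean_fused(x), batch_cov(x).mean(dim=0))
+        True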
+ """ + b, n, d = x.shape + + x_ = x - x.mean(dim=1, keepdim=True) + x_ = x_.reshape(-1, d) + return x_.mT @ x_ / (b * n) def stochastic_round_constrained(x: list[float], rng: random.Random) -> list[int]: diff --git a/elk/training/__init__.py b/elk/training/__init__.py index 292bac59..baf33b51 100644 --- a/elk/training/__init__.py +++ b/elk/training/__init__.py @@ -1,2 +1,4 @@ +from .ccs_reporter import CcsReporter, CcsReporterConfig +from .eigen_reporter import EigenReporter, EigenReporterConfig from .reporter import OptimConfig, Reporter, ReporterConfig from .train import RunConfig diff --git a/elk/training/ccs_reporter.py b/elk/training/ccs_reporter.py new file mode 100644 index 00000000..54fc2292 --- /dev/null +++ b/elk/training/ccs_reporter.py @@ -0,0 +1,319 @@ +"""An ELK reporter network.""" + +from ..parsing import parse_loss +from ..utils.typing import assert_type +from .losses import LOSSES +from .reporter import Reporter, ReporterConfig +from copy import deepcopy +from dataclasses import dataclass, field +from torch import Tensor +from torch.nn.functional import binary_cross_entropy as bce +from typing import cast, Literal, NamedTuple, Optional +import math +import torch +import torch.nn as nn + + +@dataclass +class CcsReporterConfig(ReporterConfig): + """ + Args: + activation: The activation function to use. Defaults to GELU. + bias: Whether to use a bias term in the linear layers. Defaults to True. + hidden_size: The number of hidden units in the MLP. Defaults to None. + By default, use an MLP expansion ratio of 4/3. This ratio is used by + Tucker et al. (2022) in their 3-layer + MLP probes. We could also use a ratio of 4, imitating transformer FFNs, + but this seems to lead to excessively large MLPs when num_layers > 2. + init: The initialization scheme to use. Defaults to "zero". + loss: The loss function to use. list of strings, each of the form + "coef*name", where coef is a float and name is one of the keys in + `elk.training.losses.LOSSES`. + Example: --loss 1.0*consistency_squared 0.5*prompt_var + corresponds to the loss function 1.0*consistency_squared + 0.5*prompt_var. + Defaults to "ccs_prompt_var". + num_layers: The number of layers in the MLP. Defaults to 1. + pre_ln: Whether to include a LayerNorm module before the first linear + layer. Defaults to False. + supervised_weight: The weight of the supervised loss. Defaults to 0.0. + + lr: The learning rate to use. Ignored when `optimizer` is `"lbfgs"`. + Defaults to 1e-2. + num_epochs: The number of epochs to train for. Defaults to 1000. + num_tries: The number of times to try training the reporter. Defaults to 10. + optimizer: The optimizer to use. Defaults to "adam". + weight_decay: The weight decay or L2 penalty to use. Defaults to 0.01. 
+ """ + + activation: Literal["gelu", "relu", "swish"] = "gelu" + bias: bool = True + hidden_size: Optional[int] = None + init: Literal["default", "pca", "spherical", "zero"] = "default" + loss: list[str] = field(default_factory=lambda: ["ccs"]) + loss_dict: dict[str, float] = field(default_factory=dict, init=False) + num_layers: int = 1 + pre_ln: bool = False + seed: int = 42 + supervised_weight: float = 0.0 + + lr: float = 1e-2 + num_epochs: int = 1000 + num_tries: int = 10 + optimizer: Literal["adam", "lbfgs"] = "lbfgs" + weight_decay: float = 0.01 + + def __post_init__(self): + self.loss_dict = parse_loss(self.loss) + + # standardize the loss field + self.loss = [f"{coef}*{name}" for name, coef in self.loss_dict.items()] + + +class CcsReporter(Reporter): + """An ELK reporter network. + + Args: + in_features: The number of input features. + cfg: The reporter configuration. + """ + + config: CcsReporterConfig + + def __init__( + self, + in_features: int, + cfg: CcsReporterConfig, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__(in_features, cfg, device=device, dtype=dtype) + + hidden_size = cfg.hidden_size or 4 * in_features // 3 + + self.probe = nn.Sequential( + nn.Linear( + in_features, + 1 if cfg.num_layers < 2 else hidden_size, + bias=cfg.bias, + device=device, + ), + ) + if cfg.pre_ln: + self.probe.insert(0, nn.LayerNorm(in_features, elementwise_affine=False)) + + act_cls = { + "gelu": nn.GELU, + "relu": nn.ReLU, + "swish": nn.SiLU, + }[cfg.activation] + + for i in range(1, cfg.num_layers): + self.probe.append(act_cls()) + self.probe.append( + nn.Linear( + hidden_size, + 1 if i == cfg.num_layers - 1 else hidden_size, + bias=cfg.bias, + device=device, + ) + ) + + def unsupervised_loss(self, logit0: Tensor, logit1: Tensor) -> Tensor: + loss = sum( + LOSSES[name](logit0, logit1, coef) + for name, coef in self.config.loss_dict.items() + ) + return assert_type(Tensor, loss) + + def reset_parameters(self): + """Reset the parameters of the probe. + + If init is "spherical", use the spherical initialization scheme. + If init is "default", use the default PyTorch initialization scheme for + nn.Linear (Kaiming uniform). + If init is "zero", initialize all parameters to zero. + """ + if self.config.init == "spherical": + # Mathematically equivalent to the unusual initialization scheme used in + # the original paper. They sample a Gaussian vector of dim in_features + 1, + # normalize to the unit sphere, then add an extra all-ones dimension to the + # input and compute the inner product. Here, we use nn.Linear with an + # explicit bias term, but use the same initialization. 
+ assert len(self.probe) == 1, "Only linear probes can use spherical init" + probe = cast(nn.Linear, self.probe[0]) # Pylance gets the type wrong here + + theta = torch.randn(1, probe.in_features + 1, device=probe.weight.device) + theta /= theta.norm() + probe.weight.data = theta[:, :-1] + probe.bias.data = theta[:, -1] + + elif self.config.init == "default": + for layer in self.probe: + if isinstance(layer, nn.Linear): + layer.reset_parameters() + + elif self.config.init == "zero": + for param in self.parameters(): + param.data.zero_() + elif self.config.init != "pca": + raise ValueError(f"Unknown init: {self.config.init}") + + def forward(self, x: Tensor) -> Tensor: + """Return the raw score output of the probe on `x`.""" + return self.probe(x).squeeze(-1) + + def predict(self, x_pos: Tensor, x_neg: Tensor) -> Tensor: + return 0.5 * (self(x_pos).sigmoid() + (1 - self(x_neg).sigmoid())) + + def loss( + self, + logit0: Tensor, + logit1: Tensor, + labels: Optional[Tensor] = None, + ) -> Tensor: + """Return the loss of the reporter on the contrast pair (x0, x1). + + Args: + logit0: The raw score output of the reporter on x0. + logit1: The raw score output of the reporter on x1. + labels: The labels of the contrast pair. Defaults to None. + + Returns: + loss: The loss of the reporter on the contrast pair (x0, x1). + + Raises: + ValueError: If `supervised_weight > 0` but `labels` is None. + """ + loss = self.unsupervised_loss(logit0, logit1) + + # If labels are provided, use them to compute a supervised loss + if labels is not None: + num_labels = len(labels) + assert num_labels <= len(logit0), "Too many labels provided" + p0 = logit0[:num_labels].sigmoid() + p1 = logit1[:num_labels].sigmoid() + + alpha = self.config.supervised_weight + preds = p0.add(1 - p1).mul(0.5).squeeze(-1) + bce_loss = bce(preds, labels.type_as(preds)) + loss = alpha * bce_loss + (1 - alpha) * loss + + elif self.config.supervised_weight > 0: + raise ValueError( + "Supervised weight > 0 but no labels provided to compute loss" + ) + + return loss + + def fit( + self, + x_pos: Tensor, + x_neg: Tensor, + labels: Optional[Tensor] = None, + ) -> float: + """Fit the probe to the contrast pair (x0, x1). + + Args: + contrast_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the + contrastive representations. + labels: The labels of the contrast pair. Defaults to None. + + Returns: + best_loss: The best loss obtained. + + Raises: + ValueError: If `optimizer` is not "adam" or "lbfgs". + RuntimeError: If the best loss is not finite. 
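+
+        A minimal usage sketch (shapes illustrative; x_pos and x_neg are
+        [batch, variants, dim] contrast activations):
+
+            reporter = CcsReporter(in_features=d, cfg=CcsReporterConfig())
+            best_loss = reporter.fit(x_pos, x_neg)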
+ """ + # TODO: Implement normalization here to fix issue #96 + # self.update(x_pos, x_neg) + + # Record the best acc, loss, and params found so far + best_loss = torch.inf + best_state: dict[str, Tensor] = {} # State dict of the best run + + for i in range(self.config.num_tries): + self.reset_parameters() + + # This is sort of inefficient but whatever + if self.config.init == "pca": + diffs = torch.flatten(x_pos - x_neg, 0, 1) + _, __, V = torch.pca_lowrank(diffs, q=i + 1) + self.probe[0].weight.data = V[:, -1, None].T + + if self.config.optimizer == "lbfgs": + loss = self.train_loop_lbfgs(x_pos, x_neg, labels) + elif self.config.optimizer == "adam": + loss = self.train_loop_adam(x_pos, x_neg, labels) + else: + raise ValueError(f"Optimizer {self.config.optimizer} is not supported") + + if loss < best_loss: + best_loss = loss + best_state = deepcopy(self.state_dict()) + + if not math.isfinite(best_loss): + raise RuntimeError("Got NaN/infinite loss during training") + + self.load_state_dict(best_state) + return best_loss + + def train_loop_adam( + self, + x_pos: Tensor, + x_neg: Tensor, + labels: Optional[Tensor] = None, + ) -> float: + """Adam train loop, returning the final loss. Modifies params in-place.""" + + optimizer = torch.optim.AdamW( + self.parameters(), lr=self.config.lr, weight_decay=self.config.weight_decay + ) + + loss = torch.inf + for _ in range(self.config.num_epochs): + optimizer.zero_grad() + + loss = self.loss(self(x_pos), self(x_neg), labels) + loss.backward() + optimizer.step() + + return float(loss) + + def train_loop_lbfgs( + self, + x_pos: Tensor, + x_neg: Tensor, + labels: Optional[Tensor] = None, + ) -> float: + """LBFGS train loop, returning the final loss. Modifies params in-place.""" + + optimizer = torch.optim.LBFGS( + self.parameters(), + line_search_fn="strong_wolfe", + max_iter=self.config.num_epochs, + tolerance_change=torch.finfo(x_pos.dtype).eps, + tolerance_grad=torch.finfo(x_pos.dtype).eps, + ) + # Raw unsupervised loss, WITHOUT regularization + loss = torch.inf + + def closure(): + nonlocal loss + optimizer.zero_grad() + + loss = self.loss(self(x_pos), self(x_neg), labels) + regularizer = 0.0 + + # We explicitly add L2 regularization to the loss, since LBFGS + # doesn't have a weight_decay parameter + for param in self.parameters(): + regularizer += self.config.weight_decay * param.norm() ** 2 / 2 + + regularized = loss + regularizer + regularized.backward() + + return float(regularized) + + optimizer.step(closure) + return float(loss) diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py new file mode 100644 index 00000000..3218ea70 --- /dev/null +++ b/elk/training/eigen_reporter.py @@ -0,0 +1,233 @@ +"""An ELK reporter network.""" + +from ..math_util import cov_mean_fused +from .reporter import EvalResult, Reporter, ReporterConfig +from copy import deepcopy +from dataclasses import dataclass +from torch import nn, optim, Tensor +from typing import Optional, Sequence +import torch + + +@dataclass +class EigenReporterConfig(ReporterConfig): + """Configuration for an EigenReporter. + + Args: + var_weight: The weight of the variance term in the loss. + inv_weight: The weight of the invariance term in the loss. + neg_cov_weight: The weight of the negative covariance term in the loss. + num_heads: The number of reporter heads to fit. In other words, the number + of eigenvectors to compute from the VINC matrix. 
+ """ + + var_weight: float = 1.0 + inv_weight: float = 5.0 + neg_cov_weight: float = 5.0 + + num_heads: int = 1 + + +class EigenReporter(Reporter): + """A linear reporter whose weights are computed via eigendecomposition. + + Args: + in_features: The number of input features. + cfg: The reporter configuration. + + Attributes: + config: The reporter configuration. + intercluster_cov_M2: The running sum of the covariance matrices of the + centroids of the positive and negative clusters. + intracluster_cov: The running mean of the covariance matrices within each + cluster. This doesn't need to be a running sum because it's doesn't use + Welford's algorithm. + contrastive_xcov_M2: The running sum of the cross-covariance between the + centroids of the positive and negative clusters. + n: The running sum of the number of samples in the positive and negative + clusters. + weight: The reporter weight matrix. Guaranteed to always be orthogonal, and + the columns are sorted in descending order of eigenvalue magnitude. + """ + + config: EigenReporterConfig + + intercluster_cov_M2: Tensor # variance + intracluster_cov: Tensor # invariance + contrastive_xcov_M2: Tensor # negative covariance + n: Tensor + weight: Tensor + + def __init__( + self, + in_features: int, + cfg: EigenReporterConfig, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__(in_features, cfg, device=device, dtype=dtype) + + # Learnable Platt scaling parameters + self.bias = nn.Parameter(torch.zeros(cfg.num_heads, device=device, dtype=dtype)) + self.scale = nn.Parameter(torch.ones(cfg.num_heads, device=device, dtype=dtype)) + + self.register_buffer( + "contrastive_xcov_M2", + torch.zeros(in_features, in_features, device=device, dtype=dtype), + ) + self.register_buffer( + "intercluster_cov_M2", + torch.zeros(in_features, in_features, device=device, dtype=dtype), + ) + self.register_buffer( + "intracluster_cov", + torch.zeros(in_features, in_features, device=device, dtype=dtype), + ) + self.register_buffer( + "weight", + torch.zeros(cfg.num_heads, in_features, device=device, dtype=dtype), + ) + + def forward(self, x: Tensor) -> Tensor: + """Return the predicted log odds on input `x`.""" + raw_scores = x @ self.weight.mT + return raw_scores.mul(self.scale).add(self.bias).squeeze(-1) + + def predict(self, x_pos: Tensor, x_neg: Tensor) -> Tensor: + """Return the predicted log odds on the contrast pair `(x_pos, x_neg)`.""" + return 0.5 * (self(x_pos) - self(x_neg)) + + @property + def contrastive_xcov(self) -> Tensor: + return self.contrastive_xcov_M2 / self.n + + @property + def intercluster_cov(self) -> Tensor: + return self.intercluster_cov_M2 / self.n + + def clear(self) -> None: + """Clear the running statistics of the reporter.""" + self.contrastive_xcov_M2.zero_() + self.intracluster_cov.zero_() + self.intercluster_cov_M2.zero_() + self.n.zero_() + + @torch.no_grad() + def update(self, x_pos: Tensor, x_neg: Tensor) -> None: + # Sanity checks + assert x_pos.ndim == 3, "x_pos must be of shape [batch, num_variants, d]" + assert x_pos.shape == x_neg.shape, "x_pos and x_neg must have the same shape" + + # Average across variants inside each cluster, computing the centroids. 
+ pos_centroids, neg_centroids = x_pos.mean(1), x_neg.mean(1) + + # We don't actually call super because we need access to the earlier estimate + # of the population mean in order to update (cross-)covariances properly + # super().update(x_pos, x_neg) + + sample_n = pos_centroids.shape[0] + self.n += sample_n + + # Update the running means; super().update() does this usually + neg_delta = neg_centroids - self.neg_mean + pos_delta = pos_centroids - self.pos_mean + self.neg_mean += neg_delta.sum(dim=0) / self.n + self.pos_mean += pos_delta.sum(dim=0) / self.n + + # *** Variance (inter-cluster) *** + # See code at https://bit.ly/3YC9BhH, as well as "Welford's online algorithm" + # in https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance. + # Post-mean update deltas are used to update the (co)variance + neg_delta2 = neg_centroids - self.neg_mean # [n, d] + pos_delta2 = pos_centroids - self.pos_mean # [n, d] + self.intercluster_cov_M2.addmm_(neg_delta.mT, neg_delta2) + self.intercluster_cov_M2.addmm_(pos_delta.mT, pos_delta2) + + # *** Invariance (intra-cluster) *** + # This is just a standard online *mean* update, since we're computing the + # mean of covariance matrices, not the covariance matrix of means. + sample_invar = cov_mean_fused(x_pos) + cov_mean_fused(x_neg) + self.intracluster_cov += (sample_n / self.n) * ( + sample_invar - self.intracluster_cov + ) + + # *** Negative covariance *** + self.contrastive_xcov_M2.addmm_(neg_delta.mT, pos_delta2) + self.contrastive_xcov_M2.addmm_(pos_delta.mT, neg_delta2) + + def fit_streaming(self, warm_start: bool = False) -> float: + """Fit the probe using the current streaming statistics.""" + A = ( + self.config.var_weight * self.intercluster_cov + - self.config.inv_weight * self.intracluster_cov + - self.config.neg_cov_weight * self.contrastive_xcov + ) + + # Use SciPy's sparse eigensolver for CPU tensors. This is a frontend to ARPACK, + # which uses the Lanczos method under the hood. + if A.device.type == "cpu": + from scipy.sparse.linalg import eigsh + + v0 = self.weight.T.numpy() if warm_start else None + + # We use "LA" (largest algebraic) instead of "LM" (largest magnitude) to + # ensure that the eigenvalue is positive and not a large negative one + L, Q = eigsh(A.numpy(), k=self.config.num_heads, v0=v0, which="LA") + self.weight.data = torch.from_numpy(Q).T + else: + L, Q = torch.linalg.eigh(A) + self.weight.data = Q[:, -self.config.num_heads :].T + + return -float(L[-1]) + + def fit( + self, + x_pos: Tensor, + x_neg: Tensor, + labels: Optional[Tensor] = None, + *, + platt_scale: bool = True, + ) -> float: + """Fit the probe to the contrast pair (x_pos, x_neg). + + Args: + x_pos: The positive examples. + x_neg: The negative examples. + labels: The ground truth labels if available. + platt_scale: Whether to fit the scale and bias terms to data with LBFGS. + This is only used if labels are available. + + Returns: + loss: Negative eigenvalue associated with the VINC direction. 
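+
+        A minimal sketch mirroring the unit tests (shapes illustrative):
+
+            reporter = EigenReporter(d, EigenReporterConfig())
+            loss = reporter.fit(x_pos, x_neg)  # x_pos, x_neg: [batch, variants, d]
+            logits = reporter.predict(x_pos, x_neg)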
+ """ + assert x_pos.shape == x_neg.shape + self.update(x_pos, x_neg) + loss = self.fit_streaming() + if labels is not None and platt_scale: + self.platt_scale(labels, x_pos, x_neg) + + return loss + + def platt_scale( + self, labels: Tensor, x_pos: Tensor, x_neg: Tensor, max_iter: int = 100 + ): + """Fit the scale and bias terms to data with LBFGS.""" + + opt = optim.LBFGS( + [self.bias, self.scale], + line_search_fn="strong_wolfe", + max_iter=max_iter, + tolerance_change=torch.finfo(x_pos.dtype).eps, + tolerance_grad=torch.finfo(x_pos.dtype).eps, + ) + labels = labels.repeat_interleave(x_pos.shape[1]).float() + + def closure(): + opt.zero_grad() + logits = self.predict(x_pos, x_neg).flatten() + loss = nn.functional.binary_cross_entropy_with_logits(logits, labels) + + loss.backward() + return float(loss) + + opt.step(closure) diff --git a/elk/training/losses.py b/elk/training/losses.py index 11435b10..d91c1e79 100644 --- a/elk/training/losses.py +++ b/elk/training/losses.py @@ -1,7 +1,6 @@ """Loss functions for training reporters.""" from torch import Tensor -import math import torch import warnings from inspect import signature @@ -149,5 +148,8 @@ def prompt_var_loss(logit0: Tensor, logit1: Tensor, coef: float = 1.0) -> Tensor "Only one variant provided. Prompt variance loss will cause errors." ) p0, p1 = logit0.sigmoid(), logit1.sigmoid() - prompt_variance = p0.var(dim=-1).mean() + p1.var(dim=-1).mean() + + var0 = p0.var(dim=-1, unbiased=False).mean() + var1 = p1.var(dim=-1, unbiased=False).mean() + prompt_variance = var0 + var1 return coef * prompt_variance diff --git a/elk/training/reporter.py b/elk/training/reporter.py index 543d8d21..2044b4f4 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -1,19 +1,14 @@ """An ELK reporter network.""" -from ..parsing import parse_loss -from ..utils.typing import assert_type from .classifier import Classifier -from .losses import LOSSES -from copy import deepcopy -from dataclasses import dataclass, field +from abc import ABC, abstractmethod +from dataclasses import dataclass from einops import rearrange from pathlib import Path from simple_parsing.helpers import Serializable from sklearn.metrics import roc_auc_score from torch import Tensor -from torch.nn.functional import binary_cross_entropy as bce -from typing import cast, Literal, NamedTuple, Optional, Union -import math +from typing import Literal, NamedTuple, Optional, Union import torch import torch.nn as nn @@ -25,7 +20,6 @@ class EvalResult(NamedTuple): which contains the loss, accuracy, calibrated accuracy, and AUROC. """ - loss: float acc: float cal_acc: float auroc: float @@ -35,42 +29,10 @@ class EvalResult(NamedTuple): class ReporterConfig(Serializable): """ Args: - activation: The activation function to use. Defaults to GELU. - bias: Whether to use a bias term in the linear layers. Defaults to True. - hidden_size: The number of hidden units in the MLP. Defaults to None. - By default, use an MLP expansion ratio of 4/3. This ratio is used by - Tucker et al. (2022) in their 3-layer - MLP probes. We could also use a ratio of 4, imitating transformer FFNs, - but this seems to lead to excessively large MLPs when num_layers > 2. - init: The initialization scheme to use. Defaults to "zero". - loss: The loss function to use. list of strings, each of the form - "coef*name", where coef is a float and name is one of the keys in - `elk.training.losses.LOSSES`. 
- Example: --loss 1.0*consistency_squared 0.5*prompt_var - corresponds to the loss function 1.0*consistency_squared + 0.5*prompt_var. - Defaults to "ccs_prompt_var". - num_layers: The number of layers in the MLP. Defaults to 1. - pre_ln: Whether to include a LayerNorm module before the first linear - layer. Defaults to False. - supervised_weight: The weight of the supervised loss. Defaults to 0.0. + seed: The random seed to use. Defaults to 42. """ - activation: Literal["gelu", "relu", "swish"] = "gelu" - bias: bool = True - hidden_size: Optional[int] = None - init: Literal["default", "pca", "spherical", "zero"] = "default" - loss: list[str] = field(default_factory=lambda: ["ccs"]) - loss_dict: dict[str, float] = field(default_factory=dict, init=False) - num_layers: int = 1 - pre_ln: bool = False seed: int = 42 - supervised_weight: float = 0.0 - - def __post_init__(self): - self.loss_dict = parse_loss(self.loss) - - # standardize the loss field - self.loss = [f"{coef}*{name}" for name, coef in self.loss_dict.items()] @dataclass @@ -92,7 +54,7 @@ class OptimConfig(Serializable): weight_decay: float = 0.01 -class Reporter(nn.Module): +class Reporter(nn.Module, ABC): """An ELK reporter network. Args: @@ -100,45 +62,27 @@ class Reporter(nn.Module): cfg: The reporter configuration. """ + n: Tensor + neg_mean: Tensor + pos_mean: Tensor + def __init__( - self, in_features: int, cfg: ReporterConfig, device: Optional[str] = None + self, + in_features: int, + cfg: ReporterConfig, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None, ): super().__init__() - hidden_size = cfg.hidden_size or 4 * in_features // 3 - - self.probe = nn.Sequential( - nn.Linear( - in_features, - 1 if cfg.num_layers < 2 else hidden_size, - bias=cfg.bias, - device=device, - ), + self.config = cfg + self.register_buffer("n", torch.zeros((), device=device, dtype=torch.long)) + self.register_buffer( + "neg_mean", torch.zeros(in_features, device=device, dtype=dtype) + ) + self.register_buffer( + "pos_mean", torch.zeros(in_features, device=device, dtype=dtype) ) - if cfg.pre_ln: - self.probe.insert(0, nn.LayerNorm(in_features, elementwise_affine=False)) - - act_cls = { - "gelu": nn.GELU, - "relu": nn.ReLU, - "swish": nn.SiLU, - }[cfg.activation] - - for i in range(1, cfg.num_layers): - self.probe.append(act_cls()) - self.probe.append( - nn.Linear( - hidden_size, - 1 if i == cfg.num_layers - 1 else hidden_size, - bias=cfg.bias, - device=device, - ) - ) - - self.init = cfg.init - self.device = device - self.loss_dict = cfg.loss_dict - self.supervised_weight = cfg.supervised_weight @classmethod def check_separability( @@ -185,46 +129,19 @@ def check_separability( ) return float(roc_auc_score(pseudo_val_labels.cpu(), pseudo_preds.cpu())) - def unsupervised_loss( - self, logit0: torch.Tensor, logit1: torch.Tensor - ) -> torch.Tensor: - loss = sum( - LOSSES[name](logit0, logit1, coef) for name, coef in self.loss_dict.items() - ) - return assert_type(torch.Tensor, loss) - def reset_parameters(self): - """Reset the parameters of the probe. + """Reset the parameters of the probe.""" - If init is "spherical", use the spherical initialization scheme. - If init is "default", use the default PyTorch initialization scheme for - nn.Linear (Kaiming uniform). - If init is "zero", initialize all parameters to zero. - """ - if self.init == "spherical": - # Mathematically equivalent to the unusual initialization scheme used in - # the original paper. 
They sample a Gaussian vector of dim in_features + 1, - # normalize to the unit sphere, then add an extra all-ones dimension to the - # input and compute the inner product. Here, we use nn.Linear with an - # explicit bias term, but use the same initialization. - assert len(self.probe) == 1, "Only linear probes can use spherical init" - probe = cast(nn.Linear, self.probe[0]) # Pylance gets the type wrong here - - theta = torch.randn(1, probe.in_features + 1, device=probe.weight.device) - theta /= theta.norm() - probe.weight.data = theta[:, :-1] - probe.bias.data = theta[:, -1] - - elif self.init == "default": - for layer in self.probe: - if isinstance(layer, nn.Linear): - layer.reset_parameters() - - elif self.init == "zero": - for param in self.parameters(): - param.data.zero_() - elif self.init != "pca": - raise ValueError(f"Unknown init: {self.init}") + @torch.no_grad() + def update(self, x_pos: Tensor, x_neg: Tensor) -> None: + """Update the running mean of the positive and negative examples.""" + + x_pos, x_neg = x_pos.flatten(0, -2), x_neg.flatten(0, -2) + self.n += x_pos.shape[0] + + # Update the running means + self.neg_mean += (x_neg.sum(dim=0) - self.neg_mean) / self.n + self.pos_mean += (x_pos.sum(dim=0) - self.pos_mean) / self.n # TODO: These methods will do something fancier in the future @classmethod @@ -236,124 +153,26 @@ def save(self, path: Union[Path, str]): # TODO: Save separate JSON and PT files for the reporter. torch.save(self, path) - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Return the raw score output of the probe on `x`.""" - return self.probe(x).squeeze(-1) - - def loss( - self, - logit0: torch.Tensor, - logit1: torch.Tensor, - labels: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Return the loss of the reporter on the contrast pair (x0, x1). - - Args: - logit0: The raw score output of the reporter on x0. - logit1: The raw score output of the reporter on x1. - labels: The labels of the contrast pair. Defaults to None. - - Returns: - loss: The loss of the reporter on the contrast pair (x0, x1). - - Raises: - ValueError: If `supervised_weight > 0` but `labels` is None. - """ - loss = self.unsupervised_loss(logit0, logit1) - - # If labels are provided, use them to compute a supervised loss - if labels is not None: - num_labels = len(labels) - assert num_labels <= len(logit0), "Too many labels provided" - p0 = logit0[:num_labels].sigmoid() - p1 = logit1[:num_labels].sigmoid() - - alpha = self.supervised_weight - preds = p0.add(1 - p1).mul(0.5).squeeze(-1) - bce_loss = bce(preds, labels.type_as(preds)) - loss = alpha * bce_loss + (1 - alpha) * loss - - elif self.supervised_weight > 0: - raise ValueError( - "Supervised weight > 0 but no labels provided to compute loss" - ) - - return loss - - def validate_data(self, data): - """Validate that the data's shape is valid.""" - assert len(data) == 2 and data[0].shape == data[1].shape - + @abstractmethod def fit( self, - contrast_pair: tuple[torch.Tensor, torch.Tensor], - labels: Optional[torch.Tensor] = None, - cfg: OptimConfig = OptimConfig(), + x_pos: Tensor, + x_neg: Tensor, + labels: Optional[Tensor] = None, ) -> float: - """Fit the probe to the contrast pair (x0, x1). + ... - Args: - contrast_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the - contrastive representations. - labels: The labels of the contrast pair. Defaults to None. - lr: The learning rate for Adam. Defaults to 1e-2. - num_epochs: The number of epochs to train for. Defaults to 1000. 
- num_tries: The number of times to repeat the procedure. Defaults to 10. - optimizer: The optimizer to use. Defaults to "adam". - verbose: Whether to print out information at each step. Defaults to False. - weight_decay: The weight decay for Adam. Defaults to 0.01. - - Returns: - best_loss: The best loss obtained. - - Raises: - ValueError: If `optimizer` is not "adam" or "lbfgs". - RuntimeError: If the best loss is not finite. - """ - self.validate_data(contrast_pair) - - # Record the best acc, loss, and params found so far - best_loss = torch.inf - best_state: dict[str, torch.Tensor] = {} # State dict of the best run - x0, x1 = contrast_pair - - for i in range(cfg.num_tries): - self.reset_parameters() - - # This is sort of inefficient but whatever - if self.init == "pca": - diffs = torch.flatten(x0 - x1, 0, 1) - _, __, V = torch.pca_lowrank(diffs, q=i + 1) - self.probe[0].weight.data = V[:, -1, None].T - - if cfg.optimizer == "lbfgs": - loss = self.train_loop_lbfgs(x0, x1, labels, cfg) - elif cfg.optimizer == "adam": - loss = self.train_loop_adam(x0, x1, labels, cfg) - else: - raise ValueError(f"Optimizer {cfg.optimizer} is not supported") - - if loss < best_loss: - best_loss = loss - best_state = deepcopy(self.state_dict()) - - if not math.isfinite(best_loss): - raise RuntimeError("Got NaN/infinite loss during training") - - self.load_state_dict(best_state) - return best_loss + @abstractmethod + def predict(self, x_pos: Tensor, x_neg: Tensor) -> Tensor: + """Pool the probe output on the contrast pair (x_pos, x_neg).""" @torch.no_grad() - def score( - self, - contrast_pair: tuple[torch.Tensor, torch.Tensor], - labels: torch.Tensor, - ) -> EvalResult: - """Score the probe on the contrast pair (x0, x1). + def score(self, labels: Tensor, x_pos: Tensor, x_neg: Tensor) -> EvalResult: + """Score the probe on the contrast pair (x_pos, x1). Args: - contrast_pair: A tuple of tensors, (x0, x1), where x0 and x1 are the - contrastive representations. + x_pos: The positive examples. + x_neg: The negative examples. labels: The labels of the contrast pair. Returns: @@ -361,11 +180,7 @@ def score( accuracy, and AUROC of the probe on the contrast pair (x0, x1). """ - self.validate_data(contrast_pair) - - logit0, logit1 = map(self, contrast_pair) - p0, p1 = logit0.sigmoid(), logit1.sigmoid() - pred_probs = 0.5 * (p0 + (1 - p1)) + pred_probs = self.predict(x_pos, x_neg) # Calibrated accuracy cal_thresh = pred_probs.float().quantile(labels.float().mean()) @@ -382,66 +197,7 @@ def score( raw_acc = raw_preds.flatten().eq(broadcast_labels).float().mean() return EvalResult( - loss=self.loss(logit0, logit1).item(), acc=torch.max(raw_acc, 1 - raw_acc).item(), cal_acc=torch.max(cal_acc, 1 - cal_acc).item(), auroc=max(auroc, 1 - auroc), ) - - def train_loop_adam( - self, - x0, - x1, - labels: Optional[torch.Tensor], - cfg: OptimConfig, - ) -> float: - """Adam train loop, returning the final loss. Modifies params in-place.""" - - optimizer = torch.optim.AdamW( - self.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay - ) - - loss = torch.inf - for _ in range(cfg.num_epochs): - optimizer.zero_grad() - - loss = self.loss(self(x0), self(x1), labels) - loss.backward() - optimizer.step() - - return float(loss) - - def train_loop_lbfgs( - self, x0, x1, labels: Optional[torch.Tensor], cfg: OptimConfig - ) -> float: - """LBFGS train loop, returning the final loss. 
Modifies params in-place.""" - - optimizer = torch.optim.LBFGS( - self.parameters(), - line_search_fn="strong_wolfe", - max_iter=cfg.num_epochs, - tolerance_change=torch.finfo(x0.dtype).eps, - tolerance_grad=torch.finfo(x0.dtype).eps, - ) - # Raw unsupervised loss, WITHOUT regularization - loss = torch.inf - - def closure(): - nonlocal loss - optimizer.zero_grad() - - loss = self.loss(self(x0), self(x1), labels) - regularizer = 0.0 - - # We explicitly add L2 regularization to the loss, since LBFGS - # doesn't have a weight_decay parameter - for param in self.parameters(): - regularizer += cfg.weight_decay * param.norm() ** 2 / 2 - - regularized = loss + regularizer - regularized.backward() - - return float(regularized) - - optimizer.step(closure) - return float(loss) diff --git a/elk/training/train.py b/elk/training/train.py index 5e21dcac..9ac7ad05 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -4,13 +4,15 @@ from ..files import elk_reporter_dir, memorably_named_dir from ..utils import assert_type, held_out_split, select_usable_devices, int16_to_float32 from .classifier import Classifier +from .ccs_reporter import CcsReporter, CcsReporterConfig +from .eigen_reporter import EigenReporter, EigenReporterConfig from .preprocessing import normalize from .reporter import OptimConfig, Reporter, ReporterConfig from dataclasses import dataclass from datasets import DatasetDict from functools import partial from pathlib import Path -from simple_parsing import Serializable +from simple_parsing import subgroups, Serializable from sklearn.metrics import accuracy_score, roc_auc_score from torch import Tensor from tqdm.auto import tqdm @@ -36,8 +38,10 @@ class RunConfig(Serializable): """ data: ExtractionConfig - net: ReporterConfig - optim: OptimConfig + net: ReporterConfig = subgroups( + {"ccs": CcsReporterConfig, "eigen": EigenReporterConfig}, default="eigen" + ) + optim: OptimConfig = OptimConfig() label_frac: float = 0.0 max_gpus: int = -1 @@ -93,17 +97,18 @@ def train_reporter( f"algorithm will not converge to a good solution." 
) - reporter = Reporter(x0.shape[-1], cfg.net, device=device) - if cfg.label_frac: - num_labels = round(cfg.label_frac * len(train_labels)) - labels = train_labels[:num_labels] + if isinstance(cfg.net, CcsReporterConfig): + reporter = CcsReporter(x0.shape[-1], cfg.net, device=device) + elif isinstance(cfg.net, EigenReporterConfig): + reporter = EigenReporter(x0.shape[-1], cfg.net, device=device) else: - labels = None + raise ValueError(f"Unknown reporter config type: {type(cfg.net)}") - train_loss = reporter.fit((x0, x1), labels, cfg.optim) + train_loss = reporter.fit(x0, x1, train_labels) val_result = reporter.score( - (val_x0, val_x1), val_labels, + val_x0, + val_x1, ) lr_dir = out_dir / "lr_models" @@ -137,7 +142,7 @@ def train_reporter( lr_auroc = roc_auc_score(val_labels_aug, lr_preds) stats += [lr_auroc, lr_acc] - with open(lr_dir / f"layer_{layer}.pkl", "wb") as file: + with open(lr_dir / f"layer_{layer}.pt", "wb") as file: pickle.dump(lr_model, file) with open(reporter_dir / f"layer_{layer}.pt", "wb") as file: @@ -148,16 +153,6 @@ def train_reporter( def train(cfg: RunConfig, out_dir: Optional[Path] = None): # Extract the hidden states first if necessary - is_prompt_based_loss = ( - "ccs_prompt_var" in cfg.net.loss_dict.keys() - or "prompt_var_squared" in cfg.net.loss_dict.keys() - ) - if cfg.data.prompts.num_variants == 1 and is_prompt_based_loss: - raise ValueError( - "Loss functions ccs_prompt_var and prompt_var_squared " - "incompatible with --num_variants 1." - ) - ds = extract(cfg.data, max_gpus=cfg.max_gpus) if out_dir is None: @@ -174,7 +169,7 @@ def train(cfg: RunConfig, out_dir: Optional[Path] = None): devices = select_usable_devices(cfg.max_gpus) num_devices = len(devices) - cols = ["layer", "pseudo_auroc", "train_loss", "loss", "acc", "cal_acc", "auroc"] + cols = ["layer", "pseudo_auroc", "train_loss", "acc", "cal_acc", "auroc"] if not cfg.skip_baseline: cols += ["lr_auroc", "lr_acc"] diff --git a/tests/test_classifier.py b/tests/test_classifier.py index c55a858e..bc88fa91 100644 --- a/tests/test_classifier.py +++ b/tests/test_classifier.py @@ -1,25 +1,24 @@ -import numpy as np -import pytest -import torch +from elk.training.classifier import Classifier +from elk.utils import assert_type from sklearn.datasets import make_classification from sklearn.linear_model import LogisticRegression - -from elk.training.classifier import Classifier +import numpy as np +import torch def test_classifier_roughly_same_sklearn(): input_dims: int = 10 + # make a classification problem of 1000 samples with input_dims features - features: np.ndarray - truths: np.ndarray _features, _truths = make_classification( n_samples=1000, n_features=input_dims, random_state=0 ) # use float32 for the features so it's the same as the default dtype for torch - features = _features.astype(np.float32) - truths = _truths.astype(np.float32) + features = assert_type(np.ndarray, _features).astype(np.float32) + truths = assert_type(np.ndarray, _truths).astype(np.float32) + # train a logistic regression model on the data. 
No regularization - model = LogisticRegression(penalty="none", solver="lbfgs") + model = LogisticRegression(penalty=None, solver="lbfgs") # type: ignore[arg-type] model.fit(features, truths) # train a classifier on the data classifier = Classifier(input_dim=input_dims, device="cpu") diff --git a/tests/test_eigen_reporter.py b/tests/test_eigen_reporter.py new file mode 100644 index 00000000..4564afc7 --- /dev/null +++ b/tests/test_eigen_reporter.py @@ -0,0 +1,39 @@ +from elk.math_util import batch_cov, cov_mean_fused +from elk.training import EigenReporter, EigenReporterConfig +import torch + + +def test_eigen_reporter(): + cluster_size = 5 + hidden_size = 10 + num_clusters = 100 + + x_pos = torch.randn(num_clusters, cluster_size, hidden_size, dtype=torch.float64) + x_neg = torch.randn(num_clusters, cluster_size, hidden_size, dtype=torch.float64) + x_pos1, x_pos2 = x_pos.chunk(2, dim=0) + x_neg1, x_neg2 = x_neg.chunk(2, dim=0) + + reporter = EigenReporter(hidden_size, EigenReporterConfig(), dtype=torch.float64) + reporter.update(x_pos1, x_neg1) + reporter.update(x_pos2, x_neg2) + + # Check that the streaming mean is correct + pos_mu, neg_mu = x_pos.mean(dim=(0, 1)), x_neg.mean(dim=(0, 1)) + assert torch.allclose(reporter.pos_mean, pos_mu) + assert torch.allclose(reporter.neg_mean, neg_mu) + + # Check that the streaming covariance is correct + pos_centroids, neg_centroids = x_pos.mean(dim=1), x_neg.mean(dim=1) + expected_var = batch_cov(pos_centroids) + batch_cov(neg_centroids) + assert torch.allclose(reporter.intercluster_cov, expected_var) + + # Check that the streaming invariance (intra-cluster variance) is correct + expected_invariance = cov_mean_fused(x_pos) + cov_mean_fused(x_neg) + assert torch.allclose(reporter.intracluster_cov, expected_invariance) + + # Check that the streaming negative covariance is correct + cross_cov = (pos_centroids - pos_mu).mT @ (neg_centroids - neg_mu) / num_clusters + cross_cov = cross_cov + cross_cov.mT + assert torch.allclose(reporter.contrastive_xcov, cross_cov) + + assert reporter.n == num_clusters diff --git a/tests/test_math.py b/tests/test_math.py index 8cf3016b..728fe7fe 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -1,9 +1,17 @@ -from elk.math_util import stochastic_round_constrained +from elk.math_util import batch_cov, cov_mean_fused, stochastic_round_constrained from hypothesis import given, strategies as st from random import Random import math import numpy as np import pytest +import torch + + +def test_cov_mean_fused(): + X = torch.randn(10, 500, 100, dtype=torch.float64) + cov_gt = batch_cov(X).mean(dim=0) + cov_fused = cov_mean_fused(X) + assert torch.allclose(cov_gt, cov_fused) @given( From 58e3630a8c0d0f6c64b2a58ed488a08385fde343 Mon Sep 17 00:00:00 2001 From: James Chua <30519287+thejaminator@users.noreply.github.com> Date: Thu, 16 Mar 2023 13:38:27 +0800 Subject: [PATCH 06/10] Use a matrix of python versions for the pipeline (#129) * Use a matrix of python versions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add Python 3.11 * Fix typing issue on Python 3.11; prune deps * Fix dataclass bug on 3.11 --------- Co-authored-by: James Chua Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nora Belrose --- .github/workflows/cpu_ci.yml | 34 ++++++++------------------------ elk/extraction/prompt_dataset.py | 2 +- 
elk/training/reporter.py | 8 +++++--- elk/training/train.py | 4 ++-- pyproject.toml | 9 ++------- 5 files changed, 18 insertions(+), 39 deletions(-) diff --git a/.github/workflows/cpu_ci.yml b/.github/workflows/cpu_ci.yml index 7190d1f8..4cac3fb9 100644 --- a/.github/workflows/cpu_ci.yml +++ b/.github/workflows/cpu_ci.yml @@ -3,35 +3,17 @@ name: "Run CPU Tests" on: "push" jobs: - run-tests-python3_9: - runs-on: ubuntu-latest + run-tests: + strategy: + matrix: + python-versions: [ 3.9, "3.10", "3.11" ] + os: [ ubuntu-latest, macos-latest ] + runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v3 - - name: Install Python - uses: actions/setup-python@v4 + - uses: actions/setup-python@v4 with: - python-version: "3.9" - - - name: Upgrade Pip - run: python -m pip install --upgrade pip - - - name: Install Dependencies - run: pip install -e .[dev] - - - name: Type Checking - uses: jakebailey/pyright-action@v1 - - - name: Run normal tests, excluding GPU tests - run: pytest -m "not gpu" - - run-tests-python3_10: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - name: Install Python - uses: actions/setup-python@v4 - with: - python-version: "3.10" + python-version: ${{ matrix.python-versions }} - name: Upgrade Pip run: python -m pip install --upgrade pip diff --git a/elk/extraction/prompt_dataset.py b/elk/extraction/prompt_dataset.py index c6039218..0619de06 100644 --- a/elk/extraction/prompt_dataset.py +++ b/elk/extraction/prompt_dataset.py @@ -180,7 +180,7 @@ def __getitem__(self, index: int) -> list[Prompt]: """Get a list of prompts for a given predicate""" # get self.num_variants unique prompts from the template pool template_names = self.rng.sample( - self.prompter.templates.keys(), self.num_variants + list(self.prompter.templates), self.num_variants ) example = self.active_split[index] diff --git a/elk/training/reporter.py b/elk/training/reporter.py index 2044b4f4..f5963df8 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -3,7 +3,6 @@ from .classifier import Classifier from abc import ABC, abstractmethod from dataclasses import dataclass -from einops import rearrange from pathlib import Path from simple_parsing.helpers import Serializable from sklearn.metrics import roc_auc_score @@ -121,11 +120,14 @@ def check_separability( ).repeat_interleave(val_x0.shape[1]) pseudo_clf.fit( - rearrange(torch.cat([x0, x1]), "b v d -> (b v) d"), pseudo_train_labels + # b v d -> (b v) d + torch.cat([x0, x1]).flatten(0, 1), + pseudo_train_labels, ) with torch.no_grad(): pseudo_preds = pseudo_clf( - rearrange(torch.cat([val_x0, val_x1]), "b v d -> (b v) d") + # b v d -> (b v) d + torch.cat([val_x0, val_x1]).flatten(0, 1) ) return float(roc_auc_score(pseudo_val_labels.cpu(), pseudo_preds.cpu())) diff --git a/elk/training/train.py b/elk/training/train.py index 9ac7ad05..d156b0dc 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -12,7 +12,7 @@ from datasets import DatasetDict from functools import partial from pathlib import Path -from simple_parsing import subgroups, Serializable +from simple_parsing import field, subgroups, Serializable from sklearn.metrics import accuracy_score, roc_auc_score from torch import Tensor from tqdm.auto import tqdm @@ -41,7 +41,7 @@ class RunConfig(Serializable): net: ReporterConfig = subgroups( {"ccs": CcsReporterConfig, "eigen": EigenReporterConfig}, default="eigen" ) - optim: OptimConfig = OptimConfig() + optim: OptimConfig = field(default_factory=OptimConfig) label_frac: float = 0.0 max_gpus: int = -1 diff --git 
a/pyproject.toml b/pyproject.toml index 1e69ab37..b93ea61d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,12 +12,8 @@ license = {text = "MIT License"} dependencies = [ # Added Dataset.from_generator() method "datasets>=2.5.0", - # TODO: consider removing this dependency since we only use it once - "einops", # Introduced numpy.typing module "numpy>=1.20.0", - # Introduced type annotations - "prettytable>=3.5.0", # This version is old, but it's needed for certain HF tokenizers to work. "protobuf==3.20.*", # Basically any version should work as long as it supports the user's CUDA version @@ -26,8 +22,8 @@ dependencies = [ "scikit-learn>=1.0.0", # Needed for certain HF tokenizers "sentencepiece==0.1.97", - # Support for Literal types was added in 0.0.21 - "simple-parsing>=0.0.21", + # We upstreamed bugfixes for Literal types in 0.1.1 + "simple-parsing>=0.1.1", # Version 1.11 introduced Fully Sharded Data Parallel, which we plan to use soon "torch>=1.11.0", # Doesn't really matter but versions < 4.0 are very very old (pre-2016) @@ -55,7 +51,6 @@ include = ["elk*"] reportPrivateImportUsage = false [tool.pytest.ini_options] -markers = ["cpu: Marker for tests that do not depend on GPUs"] testpaths = ["tests"] [tool.setuptools.packages.find] From 964c662c7e2e2468ed3ec8949a627ae5104f3d1c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Mar 2023 22:39:01 -0700 Subject: [PATCH 07/10] [pre-commit.ci] pre-commit autoupdate (#125) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/codespell-project/codespell: v2.2.2 → v2.2.4](https://github.com/codespell-project/codespell/compare/v2.2.2...v2.2.4) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 368066c3..931bf58d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,7 +20,7 @@ repos: - id: flake8 args: ["--ignore=E203,F401,W503", --max-line-length=88] - repo: https://github.com/codespell-project/codespell - rev: v2.2.2 + rev: v2.2.4 hooks: - id: codespell # The promptsource templates spuriously get flagged without this From bb8fadfb7d773d595dc61704c38c64a947083032 Mon Sep 17 00:00:00 2001 From: Nora Belrose <39116809+norabelrose@users.noreply.github.com> Date: Wed, 15 Mar 2023 23:35:34 -0700 Subject: [PATCH 08/10] Add a calibration error statistic (#126) * Added error message for prompt-based loss and num_variants=1 * Added num_variants and ccs_prompt_var error message * changed prompt_var "Only one variant provided. Prompt variance loss will equal CCS loss." 
string to be accurate
* changed default loss to ccs
* Draft commit
* Break Reporter into CcsReporter and EigenReporter
* Fix transpose bug
* Auto choose solver for device
* Initial support for streaming VINC
* Tests for streaming VINC
* Fix CcsReporter type check bug
* Add fit_streaming
* Platt scaling
* Platt scaling by default
* Add expected calibration error stat
* Remove vestigial 'uniform' binning option
* Rename confidences -> pred_probs
* Move comment

---------

Co-authored-by: Benjamin
---
 elk/calibration.py             | 90 ++++++++++++++++++++++++++++++++++
 elk/math_util.py               |  1 -
 elk/training/eigen_reporter.py |  5 +-
 elk/training/reporter.py       | 17 +++++--
 elk/training/train.py          | 10 +++-
 5 files changed, 114 insertions(+), 9 deletions(-)
 create mode 100644 elk/calibration.py

diff --git a/elk/calibration.py b/elk/calibration.py
new file mode 100644
index 00000000..3d494872
--- /dev/null
+++ b/elk/calibration.py
@@ -0,0 +1,90 @@
+from dataclasses import dataclass, field
+from torch import Tensor
+from typing import NamedTuple
+import torch
+import warnings
+
+
+class CalibrationEstimate(NamedTuple):
+    ece: float
+    num_bins: int
+
+
+@dataclass
+class CalibrationError:
+    """Monotonic Sweep Calibration Error for binary problems.
+
+    This method estimates the True Calibration Error (TCE) by searching for the largest
+    number of bins into which the data can be split that preserves the monotonicity
+    of the predicted confidence -> empirical accuracy mapping. We use equal mass bins
+    (quantiles) instead of equal width bins. Roelofs et al. (2020) show that this
+    estimator has especially low bias in simulations where the TCE is analytically
+    computable, and is hyperparameter-free (except for the type of norm used).
+
+    Paper: "Mitigating Bias in Calibration Error Estimation" by Roelofs et al. (2020)
+    Link: https://arxiv.org/abs/2012.08668
+    """
+
+    labels: list[Tensor] = field(default_factory=list)
+    pred_probs: list[Tensor] = field(default_factory=list)
+
+    def update(self, labels: Tensor, probs: Tensor) -> "CalibrationError":
+        labels, probs = labels.detach().flatten(), probs.detach().flatten()
+        assert labels.shape == probs.shape
+        assert torch.is_floating_point(probs)
+
+        self.labels.append(labels)
+        self.pred_probs.append(probs)
+        return self
+
+    def compute(self, p: int = 2) -> CalibrationEstimate:
+        """Compute the expected calibration error.
+
+        Args:
+            p: The norm to use for the calibration error. Defaults to 2 (Euclidean).
+        """
+        labels = torch.cat(self.labels)
+        pred_probs = torch.cat(self.pred_probs)
+
+        n = len(pred_probs)
+        if n < 2:
+            raise ValueError("Not enough data to compute calibration error.")
+
+        # Sort the predictions and labels
+        pred_probs, indices = pred_probs.sort()
+        labels = labels[indices].float()
+
+        # Search for the largest number of bins which preserves monotonicity.
+        # Based on Algorithm 1 in Roelofs et al. (2020).
+        # Using a single bin is guaranteed to be monotonic, so we start there.
+        b_star, accs_star = 1, labels.mean().unsqueeze(0)
+        for b in range(2, n + 1):
+            # Split into (nearly) equal mass bins
+            freqs = torch.stack([h.mean() for h in labels.tensor_split(b)])
+
+            # This binning is not strictly monotonic, so stop the search
+            if not torch.all(freqs[1:] > freqs[:-1]):
+                break
+
+            elif not torch.all(freqs * (1 - freqs)):
+                warnings.warn(
+                    "Calibration error estimate may be unreliable due to insufficient"
+                    " data in some bins."
+ ) + break + + # Save the current binning, it's monotonic and may be the best one + else: + accs_star = freqs + b_star = b + + # Split into (nearly) equal mass bins. They won't be exactly equal, so we + # still weight the bins by their size. + conf_bins = pred_probs.tensor_split(b_star) + w = torch.tensor([len(c) / n for c in conf_bins]) + + # See the definition of ECE_sweep in Equation 8 of Roelofs et al. (2020) + mean_confs = torch.stack([c.mean() for c in conf_bins]) + ece = torch.sum(w * torch.abs(accs_star - mean_confs) ** p) ** (1 / p) + + return CalibrationEstimate(float(ece), b_star) diff --git a/elk/math_util.py b/elk/math_util.py index 6ab7ef3f..7b5cd38c 100644 --- a/elk/math_util.py +++ b/elk/math_util.py @@ -1,5 +1,4 @@ from torch import Tensor -from typing import Optional import math import random import torch diff --git a/elk/training/eigen_reporter.py b/elk/training/eigen_reporter.py index 3218ea70..5d301b80 100644 --- a/elk/training/eigen_reporter.py +++ b/elk/training/eigen_reporter.py @@ -1,11 +1,10 @@ """An ELK reporter network.""" from ..math_util import cov_mean_fused -from .reporter import EvalResult, Reporter, ReporterConfig -from copy import deepcopy +from .reporter import Reporter, ReporterConfig from dataclasses import dataclass from torch import nn, optim, Tensor -from typing import Optional, Sequence +from typing import Optional import torch diff --git a/elk/training/reporter.py b/elk/training/reporter.py index f5963df8..ea8a4406 100644 --- a/elk/training/reporter.py +++ b/elk/training/reporter.py @@ -1,5 +1,6 @@ """An ELK reporter network.""" +from ..calibration import CalibrationError from .classifier import Classifier from abc import ABC, abstractmethod from dataclasses import dataclass @@ -22,6 +23,7 @@ class EvalResult(NamedTuple): acc: float cal_acc: float auroc: float + ece: float @dataclass @@ -184,15 +186,21 @@ def score(self, labels: Tensor, x_pos: Tensor, x_neg: Tensor) -> EvalResult: pred_probs = self.predict(x_pos, x_neg) + # makes `num_variants` copies of each label, all within a single + # dimension of size `num_variants * n`, such that the labels align + # with pred_probs.flatten() + broadcast_labels = labels.repeat_interleave(pred_probs.shape[1]).float() + cal_err = ( + CalibrationError() + .update(broadcast_labels.cpu(), pred_probs.cpu()) + .compute() + ) + # Calibrated accuracy cal_thresh = pred_probs.float().quantile(labels.float().mean()) cal_preds = pred_probs.gt(cal_thresh).squeeze(1).to(torch.int) raw_preds = pred_probs.gt(0.5).squeeze(1).to(torch.int) - # makes `num_variants` copies of each label, all within a single - # dimension of size `num_variants * n`, such that the labels align - # with pred_probs.flatten() - broadcast_labels = labels.repeat_interleave(pred_probs.shape[1]) # roc_auc_score only takes flattened input auroc = float(roc_auc_score(broadcast_labels.cpu(), pred_probs.cpu().flatten())) cal_acc = cal_preds.flatten().eq(broadcast_labels).float().mean() @@ -202,4 +210,5 @@ def score(self, labels: Tensor, x_pos: Tensor, x_neg: Tensor) -> EvalResult: acc=torch.max(raw_acc, 1 - raw_acc).item(), cal_acc=torch.max(cal_acc, 1 - cal_acc).item(), auroc=max(auroc, 1 - auroc), + ece=cal_err.ece, ) diff --git a/elk/training/train.py b/elk/training/train.py index d156b0dc..3ecd3e62 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -169,7 +169,15 @@ def train(cfg: RunConfig, out_dir: Optional[Path] = None): devices = select_usable_devices(cfg.max_gpus) num_devices = len(devices) - cols = ["layer", "pseudo_auroc", 
"train_loss", "acc", "cal_acc", "auroc"] + cols = [ + "layer", + "pseudo_auroc", + "train_loss", + "acc", + "cal_acc", + "auroc", + "ece", + ] if not cfg.skip_baseline: cols += ["lr_auroc", "lr_acc"] From dac13316453dfdfe4af998a7a703bf80f4a02a95 Mon Sep 17 00:00:00 2001 From: Alex Mallen <35092692+AlexTMallen@users.noreply.github.com> Date: Thu, 16 Mar 2023 00:23:19 -0700 Subject: [PATCH 09/10] add support for piqa and super_glue copa (#131) * Added error message for prompt-based loss and num_variants=1 * Added num_variants and ccs_prompt_var error message * changed prompt_var "Only one variant provided. Prompt variance loss will equal CCS loss." string to be accurate * changed default loss to ccs * Draft commit * Break Reporter into CcsReporter and EigenReporter * Fix transpose bug * Auto choose solver for device * Initial support for streaming VINC * Tests fr streaming VINC * Fix CcsReporter type check bug * Add fit_streaming * Platt scaling * Platt scaling by default * cleanup eigen_reporter * rename contrastive_cov * fix duplicate "intracluster_cov_M2" * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add --net to readme * Update README.md * Update README.md * rename EigenReporter attributes in test_eigen_reporter.py * add support for piqa and super_glue copa --------- Co-authored-by: Benjamin Co-authored-by: Nora Belrose Co-authored-by: Walter Laurito Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- elk/extraction/extraction.py | 20 ++- .../templates/piqa/templates.yaml | 82 ++++++------ .../templates/super_glue/copa/templates.yaml | 120 +++++++++--------- 3 files changed, 116 insertions(+), 106 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 5ca7b794..d0554c19 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -219,25 +219,35 @@ def extract(cfg: ExtractionConfig, max_gpus: int = -1) -> DatasetDict: """Extract hidden states from a model and return a `DatasetDict` containing them.""" def get_splits() -> SplitDict: + # Valid splits: TrVal, TrTe, ValTe, TrValTe(->TrVal) base_splits = assert_type(SplitDict, info.splits) splits = set(base_splits) & {Split.TRAIN, Split.VALIDATION, Split.TEST} - if Split.VALIDATION in splits and Split.TEST in splits: + if ( + Split.VALIDATION in splits + and Split.TEST in splits + and Split.TRAIN in splits + ): splits.remove(Split.TEST) + assert len(splits) == 2, "Must have at least two of train, val, and test splits" - assert len(splits) == 2, "Must have train and val/test splits" - val_split = Split.VALIDATION if Split.VALIDATION in splits else Split.TEST + train_split = Split.TRAIN if Split.TRAIN in splits else Split.VALIDATION + val_split = ( + Split.VALIDATION + if Split.VALIDATION and Split.TRAIN in splits + else Split.TEST + ) # grab the max number of examples from the config for each split limit = ( { - Split.TRAIN: cfg.prompts.max_examples[0], + train_split: cfg.prompts.max_examples[0], val_split: cfg.prompts.max_examples[0] if len(cfg.prompts.max_examples) == 1 else cfg.prompts.max_examples[1], } if cfg.prompts.max_examples else { - Split.TRAIN: int(1e100), + train_split: int(1e100), val_split: int(1e100), } ) diff --git a/elk/promptsource/templates/piqa/templates.yaml b/elk/promptsource/templates/piqa/templates.yaml index b9a60764..a94de3a8 100644 --- a/elk/promptsource/templates/piqa/templates.yaml +++ b/elk/promptsource/templates/piqa/templates.yaml @@ -85,22 +85,22 @@ 
templates: original_task: true name: pick_correct_choice_index reference: '' - 5f4b4645-9438-4375-9062-083130e6d04e: !Template - answer_choices: null - id: 5f4b4645-9438-4375-9062-083130e6d04e - jinja: "Given a goal and a wrong solution, rewrite it to give a correct solution.\n\ - Goal: {{goal}} \nSolution: {{[sol1, sol2][1 - label]}}\nCorrected solution:\n\ - |||\n{{[sol1, sol2][label]}}\n" - metadata: !TemplateMetadata - choices_in_prompt: false - languages: - - en - metrics: - - BLEU - - ROUGE - original_task: false - name: Correct the solution - reference: '' + # 5f4b4645-9438-4375-9062-083130e6d04e: !Template + # answer_choices: null + # id: 5f4b4645-9438-4375-9062-083130e6d04e + # jinja: "Given a goal and a wrong solution, rewrite it to give a correct solution.\n\ + # Goal: {{goal}} \nSolution: {{[sol1, sol2][1 - label]}}\nCorrected solution:\n\ + # |||\n{{[sol1, sol2][label]}}\n" + # metadata: !TemplateMetadata + # choices_in_prompt: false + # languages: + # - en + # metrics: + # - BLEU + # - ROUGE + # original_task: false + # name: Correct the solution + # reference: '' 94c39589-7bfb-4c09-9337-672369459545: !Template answer_choices: '{{sol1}} ||| {{sol2}}' id: 94c39589-7bfb-4c09-9337-672369459545 @@ -238,28 +238,28 @@ templates: original_task: false name: Does this solution make sense? sol1 reference: '' - f42cd457-a14b-465a-a139-d7d2407a3bac: !Template - answer_choices: null - id: f42cd457-a14b-465a-a139-d7d2407a3bac - jinja: 'Sentence: {{goal}} {{sol1[0].lower() + sol1[1:]}} - - If the sentence does not make sense, correct it so that it does make sense. - Otherwise, just copy it. - - Answer: - - ||| - - {{goal}} {{[sol1[0].lower() + sol1[1:], sol2[0].lower() + sol2[1:]][label]}} - - ' - metadata: !TemplateMetadata - choices_in_prompt: false - languages: - - en - metrics: - - BLEU - - ROUGE - original_task: false - name: 'Correct the solution if false: from sol 1' - reference: '' + # f42cd457-a14b-465a-a139-d7d2407a3bac: !Template + # answer_choices: null + # id: f42cd457-a14b-465a-a139-d7d2407a3bac + # jinja: 'Sentence: {{goal}} {{sol1[0].lower() + sol1[1:]}} + + # If the sentence does not make sense, correct it so that it does make sense. + # Otherwise, just copy it. + + # Answer: + + # ||| + + # {{goal}} {{[sol1[0].lower() + sol1[1:], sol2[0].lower() + sol2[1:]][label]}} + + # ' + # metadata: !TemplateMetadata + # choices_in_prompt: false + # languages: + # - en + # metrics: + # - BLEU + # - ROUGE + # original_task: false + # name: 'Correct the solution if false: from sol 1' + # reference: '' diff --git a/elk/promptsource/templates/super_glue/copa/templates.yaml b/elk/promptsource/templates/super_glue/copa/templates.yaml index 5c48d168..f21781d8 100644 --- a/elk/promptsource/templates/super_glue/copa/templates.yaml +++ b/elk/promptsource/templates/super_glue/copa/templates.yaml @@ -22,21 +22,21 @@ templates: original_task: true name: exercise reference: '' - 150789fe-e309-47a1-82c9-0a4dc2c6b12b: !Template - answer_choices: '{{choice1}} ||| {{choice2}}' - id: 150789fe-e309-47a1-82c9-0a4dc2c6b12b - jinja: "{% if question == \"effect\" %} \n{{ premise }} What could happen next,\ - \ \"{{ answer_choices[0] }}\" or \"{{ answer_choices[1] }}\"? ||| {% if label\ - \ != -1 %}{{ answer_choices[label] }}{%endif%}\n{% endif %}" - metadata: !TemplateMetadata - choices_in_prompt: true - languages: - - en - metrics: - - Accuracy - original_task: true - name: "\u2026What could happen next, C1 or C2?" 
- reference: '' + # 150789fe-e309-47a1-82c9-0a4dc2c6b12b: !Template + # answer_choices: '{{choice1}} ||| {{choice2}}' + # id: 150789fe-e309-47a1-82c9-0a4dc2c6b12b + # jinja: "{% if question == \"effect\" %} \n{{ premise }} What could happen next,\ + # \ \"{{ answer_choices[0] }}\" or \"{{ answer_choices[1] }}\"? ||| {% if label\ + # \ != -1 %}{{ answer_choices[label] }}{%endif%}\n{% endif %}" + # metadata: !TemplateMetadata + # choices_in_prompt: true + # languages: + # - en + # metrics: + # - Accuracy + # original_task: true + # name: "\u2026What could happen next, C1 or C2?" + # reference: '' 4d879cbe-2fd7-424a-9d78-3f5200313fba: !Template answer_choices: '{{choice1}} ||| {{choice2}}' id: 4d879cbe-2fd7-424a-9d78-3f5200313fba @@ -105,21 +105,21 @@ templates: original_task: true name: "C1 or C2? premise, so/because\u2026" reference: "Adapted from Perez et al. 2021 and Schick & Sch\xFCtz 2021." - 84da62c2-9440-4cfc-bdd4-d70c65e33a82: !Template - answer_choices: '{{choice1}} ||| {{choice2}}' - id: 84da62c2-9440-4cfc-bdd4-d70c65e33a82 - jinja: "{% if question == \"effect\" %} \n{{ premise }} As a result, \"{{ answer_choices[0]\ - \ }}\" or \"{{ answer_choices[1] }}\"? ||| {% if label != -1 %}{{ answer_choices[label]\ - \ }}{%endif%}\n{% endif %}" - metadata: !TemplateMetadata - choices_in_prompt: true - languages: - - en - metrics: - - Accuracy - original_task: true - name: "\u2026As a result, C1 or C2?" - reference: '' + # 84da62c2-9440-4cfc-bdd4-d70c65e33a82: !Template + # answer_choices: '{{choice1}} ||| {{choice2}}' + # id: 84da62c2-9440-4cfc-bdd4-d70c65e33a82 + # jinja: "{% if question == \"effect\" %} \n{{ premise }} As a result, \"{{ answer_choices[0]\ + # \ }}\" or \"{{ answer_choices[1] }}\"? ||| {% if label != -1 %}{{ answer_choices[label]\ + # \ }}{%endif%}\n{% endif %}" + # metadata: !TemplateMetadata + # choices_in_prompt: true + # languages: + # - en + # metrics: + # - Accuracy + # original_task: true + # name: "\u2026As a result, C1 or C2?" + # reference: '' 8ce80f8a-239e-4393-892c-f63dbb0d9929: !Template answer_choices: '{{choice1}} ||| {{choice2}}' id: 8ce80f8a-239e-4393-892c-f63dbb0d9929 @@ -135,21 +135,21 @@ templates: original_task: true name: best_option reference: '' - 8cf2ba73-aee5-4651-b5d4-b1b88afe4abb: !Template - answer_choices: '{{choice1}} ||| {{choice2}}' - id: 8cf2ba73-aee5-4651-b5d4-b1b88afe4abb - jinja: "{% if question == \"cause\" %} \n{{ premise }} Which may be caused by\ - \ \"{{ answer_choices[0] }}\" or \"{{ answer_choices[1] }}\"? ||| {% if label\ - \ != -1 %}{{ answer_choices[label] }}{%endif%}\n{% endif %}" - metadata: !TemplateMetadata - choices_in_prompt: true - languages: - - en - metrics: - - Accuracy - original_task: true - name: "\u2026which may be caused by" - reference: '' + # 8cf2ba73-aee5-4651-b5d4-b1b88afe4abb: !Template + # answer_choices: '{{choice1}} ||| {{choice2}}' + # id: 8cf2ba73-aee5-4651-b5d4-b1b88afe4abb + # jinja: "{% if question == \"cause\" %} \n{{ premise }} Which may be caused by\ + # \ \"{{ answer_choices[0] }}\" or \"{{ answer_choices[1] }}\"? 
||| {% if label\ + # \ != -1 %}{{ answer_choices[label] }}{%endif%}\n{% endif %}" + # metadata: !TemplateMetadata + # choices_in_prompt: true + # languages: + # - en + # metrics: + # - Accuracy + # original_task: true + # name: "\u2026which may be caused by" + # reference: '' a1f9951e-2b6b-4530-9636-9cdf4c1658c5: !Template answer_choices: '{{choice1}} ||| {{choice2}}' id: a1f9951e-2b6b-4530-9636-9cdf4c1658c5 @@ -191,21 +191,21 @@ templates: original_task: true name: cause_effect reference: '' - a8bf11c3-bea2-45ba-a533-957d8bee5e2e: !Template - answer_choices: '{{choice1}} ||| {{choice2}}' - id: a8bf11c3-bea2-45ba-a533-957d8bee5e2e - jinja: "{% if question == \"cause\" %} \n{{ premise }} Why? \"{{ answer_choices[0]\ - \ }}\" or \"{{ answer_choices[1] }}\"? ||| {% if label != -1 %}{{ answer_choices[label]\ - \ }}{%endif%}\n{% endif %}" - metadata: !TemplateMetadata - choices_in_prompt: true - languages: - - en - metrics: - - Accuracy - original_task: true - name: "\u2026why? C1 or C2" - reference: '' + # a8bf11c3-bea2-45ba-a533-957d8bee5e2e: !Template + # answer_choices: '{{choice1}} ||| {{choice2}}' + # id: a8bf11c3-bea2-45ba-a533-957d8bee5e2e + # jinja: "{% if question == \"cause\" %} \n{{ premise }} Why? \"{{ answer_choices[0]\ + # \ }}\" or \"{{ answer_choices[1] }}\"? ||| {% if label != -1 %}{{ answer_choices[label]\ + # \ }}{%endif%}\n{% endif %}" + # metadata: !TemplateMetadata + # choices_in_prompt: true + # languages: + # - en + # metrics: + # - Accuracy + # original_task: true + # name: "\u2026why? C1 or C2" + # reference: '' f32348cd-d3cb-4619-87b9-e24f99c78567: !Template answer_choices: '{{choice1}} ||| {{choice2}}' id: f32348cd-d3cb-4619-87b9-e24f99c78567 From 69da185a72a7147ec5f5478d0d84e66c029b2600 Mon Sep 17 00:00:00 2001 From: Nora Belrose <39116809+norabelrose@users.noreply.github.com> Date: Thu, 16 Mar 2023 11:54:20 -0700 Subject: [PATCH 10/10] Fix broken handling of splits (#132) --- elk/extraction/extraction.py | 57 +++++++++++++++--------------------- 1 file changed, 24 insertions(+), 33 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index d0554c19..66e3d128 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -219,38 +219,29 @@ def extract(cfg: ExtractionConfig, max_gpus: int = -1) -> DatasetDict: """Extract hidden states from a model and return a `DatasetDict` containing them.""" def get_splits() -> SplitDict: - # Valid splits: TrVal, TrTe, ValTe, TrValTe(->TrVal) - base_splits = assert_type(SplitDict, info.splits) - splits = set(base_splits) & {Split.TRAIN, Split.VALIDATION, Split.TEST} - if ( - Split.VALIDATION in splits - and Split.TEST in splits - and Split.TRAIN in splits - ): - splits.remove(Split.TEST) - assert len(splits) == 2, "Must have at least two of train, val, and test splits" - - train_split = Split.TRAIN if Split.TRAIN in splits else Split.VALIDATION - val_split = ( - Split.VALIDATION - if Split.VALIDATION and Split.TRAIN in splits - else Split.TEST - ) + available_splits = assert_type(SplitDict, info.splits) + priorities = { + Split.TRAIN: 0, + Split.VALIDATION: 1, + Split.TEST: 2, + } + splits = sorted(available_splits, key=lambda k: priorities.get(k, 100)) + assert len(splits) >= 2, "Must have train and val/test splits" - # grab the max number of examples from the config for each split - limit = ( - { - train_split: cfg.prompts.max_examples[0], - val_split: cfg.prompts.max_examples[0] - if len(cfg.prompts.max_examples) == 1 - else cfg.prompts.max_examples[1], - } - if 
cfg.prompts.max_examples - else { - train_split: int(1e100), - val_split: int(1e100), - } - ) + # Take the first two splits + splits = splits[:2] + print(f"Using '{splits[0]}' for training and '{splits[1]}' for validation") + + # Empty list means no limit + limit_list = cfg.prompts.max_examples + if not limit_list: + limit_list = [int(1e100)] + + # Broadcast the limit to all splits + if len(limit_list) == 1: + limit_list *= len(splits) + + limit = {k: v for k, v in zip(splits, limit_list)} return SplitDict( { k: SplitInfo( @@ -258,10 +249,10 @@ def get_splits() -> SplitDict: num_examples=min(limit[k], v.num_examples), dataset_name=v.dataset_name, ) - for k, v in base_splits.items() + for k, v in available_splits.items() if k in splits }, - dataset_name=base_splits.dataset_name, + dataset_name=available_splits.dataset_name, ) model_cfg = AutoConfig.from_pretrained(cfg.model)
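
The test_cov_mean_fused test added above checks that elk's cov_mean_fused agrees with averaging batch_cov over the cluster dimension. For reference, that identity can be restated standalone. The sketch below is illustrative only: batch_cov_ref and cov_mean_fused_ref are hypothetical reference implementations written for this note (elk's own functions may differ in normalization convention), not the library code.

import torch

def batch_cov_ref(x: torch.Tensor) -> torch.Tensor:
    # Per-cluster covariance: (clusters, samples, features) -> (clusters, d, d)
    centered = x - x.mean(dim=1, keepdim=True)
    return centered.mT @ centered / x.shape[1]

def cov_mean_fused_ref(x: torch.Tensor) -> torch.Tensor:
    # Mean of the per-cluster covariances, fused into a single matmul
    b, n, d = x.shape
    centered = x - x.mean(dim=1, keepdim=True)
    flat = centered.reshape(-1, d)  # stack all centered clusters
    return flat.mT @ flat / (b * n)  # equals mean over i of C_i^T C_i / n

X = torch.randn(10, 500, 100, dtype=torch.float64)
assert torch.allclose(batch_cov_ref(X).mean(dim=0), cov_mean_fused_ref(X))

The fused form avoids materializing a (clusters, d, d) intermediate, which is why a dedicated cov_mean_fused is worth having alongside batch_cov.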
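
The einops removal in check_separability (patch 06 above) relies on a small equivalence: for a 3-D tensor, rearrange(t, "b v d -> (b v) d") and t.flatten(0, 1) produce the same row-major result. A quick illustrative check:

import torch

t = torch.randn(4, 3, 8)  # (batch, variants, hidden)
flat = t.flatten(0, 1)    # merge the first two dims, i.e. "(b v) d"
assert flat.shape == (12, 8)
assert torch.equal(flat, t.reshape(12, 8))  # same as a row-major reshape
assert torch.equal(flat[1], t[0, 1])        # rows are batch-major, then variant

This is why the einops dependency could be pruned from pyproject.toml without changing behavior.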
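
For the calibration statistic added in patch 08, here is a minimal usage sketch, assuming the package is installed with these patches applied so that elk.calibration is importable. Labels drawn from the predicted probabilities themselves are calibrated by construction, so the reported error should be close to zero.

import torch
from elk.calibration import CalibrationError

torch.manual_seed(0)
probs = torch.rand(5000)         # predicted P(label == 1)
labels = torch.bernoulli(probs)  # sampled so the predictions are calibrated

est = CalibrationError().update(labels, probs).compute(p=2)
print(f"ECE_sweep = {est.ece:.4f} over {est.num_bins} bins")

Here p=2 selects the Euclidean norm from Equation 8 of Roelofs et al. (2020); p=1 gives the absolute-error variant.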
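
Finally, the get_splits rewrite in patch 10 boils down to a simple rule: rank the available splits train < validation < test (unknown names last), keep the first two, and broadcast a single max_examples entry to both. The function below is a hypothetical restatement with plain strings; the real code operates on datasets.Split keys and rebuilds SplitInfo objects, so treat this as a sketch of the selection logic only.

def pick_splits(available: list[str], max_examples: list[int]) -> dict[str, int]:
    """Choose (train-like, val-like) splits and a per-split example limit."""
    priorities = {"train": 0, "validation": 1, "test": 2}
    splits = sorted(available, key=lambda k: priorities.get(k, 100))[:2]
    assert len(splits) == 2, "Must have train and val/test splits"

    limits = max_examples or [int(1e100)]  # empty list means no limit
    if len(limits) == 1:
        limits = limits * 2  # broadcast one limit to both splits
    return dict(zip(splits, limits))

print(pick_splits(["test", "validation", "train"], [1000]))
# -> {'train': 1000, 'validation': 1000}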