Make datasets & transformers optional dependencies

norabelrose committed Aug 18, 2023
1 parent 6bd9895 commit 1547093
Showing 9 changed files with 25 additions and 184 deletions.
7 changes: 1 addition & 6 deletions concept_erasure/__init__.py
@@ -1,17 +1,12 @@
 from .concept_scrubber import ConceptScrubber
-from .data import chunk_and_tokenize
 from .leace import ErasureMethod, LeaceEraser, LeaceFitter
 from .oracle import OracleEraser, OracleFitter
-from .random_scrub import random_scrub
 from .shrinkage import optimal_linear_shrinkage
-from .utils import assert_type, chunk
+from .utils import assert_type
 
 __all__ = [
     "assert_type",
-    "chunk",
-    "chunk_and_tokenize",
     "optimal_linear_shrinkage",
-    "random_scrub",
     "ConceptScrubber",
     "LeaceEraser",
     "LeaceFitter",
6 changes: 4 additions & 2 deletions concept_erasure/concept_scrubber.py
@@ -3,7 +3,6 @@
 from typing import Callable
 
 from torch import Tensor, nn
-from transformers import PreTrainedModel
 
 from .leace import LeaceEraser
 from .utils import assert_type, is_norm_layer, mangle_module_path
@@ -46,7 +45,10 @@ def pre_wrapper(_, inputs, name: str) -> tuple[Tensor | None, ...]:
             key = mangle_module_path(name)
             return hook_fn(key, x), *extras
 
-        # Unwrap the base model if necessary
+        # Unwrap the base model if necessary. This is needed to ensure we don't try to
+        # scrub right before the unembedding layer
+        from transformers import PreTrainedModel
+
         if isinstance(model, PreTrainedModel):
             model = assert_type(PreTrainedModel, model.base_model)
 
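The hunk above moves the transformers import from module scope into the function body, so `import concept_erasure` itself no longer needs transformers installed; the cost is only paid when the scrubbing helper actually runs. A minimal sketch of the same deferred-import pattern (hypothetical function name, not this file's code):

def unwrap_base_model(model):
    """Return the underlying base model if given a full Hugging Face wrapper."""
    # Imported here rather than at module scope, so importing this module
    # never requires transformers to be installed.
    from transformers import PreTrainedModel

    if isinstance(model, PreTrainedModel):
        return model.base_model
    return model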
136 changes: 0 additions & 136 deletions concept_erasure/data.py

This file was deleted.

3 changes: 2 additions & 1 deletion concept_erasure/scrubbing/__init__.py
@@ -1,3 +1,4 @@
 from .auto import patch_attention_, scrub
+from .random_scrub import random_scrub
 
-__all__ = ["patch_attention_", "scrub"]
+__all__ = ["patch_attention_", "random_scrub", "scrub"]
2 changes: 1 addition & 1 deletion concept_erasure/scrubbing/auto.py
@@ -37,7 +37,7 @@ def scrub(
     method: ErasureMethod = "leace",
     sublayers: bool = True,
 ) -> tuple[ConceptScrubber | None, float]:
-    """Scrub a model to remove the fast attention kernels."""
+    """Apply concept scrubbing to `model` on dataset `train`, returning the scrubber."""
     if isinstance(model, GPTNeoXForCausalLM):
         return scrub_neox(model, train, z_column, batch_size, method, affine)
     elif isinstance(model, LlamaForCausalLM):
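The corrected docstring reflects what `scrub` actually does: it dispatches on the model class and returns the fitted `ConceptScrubber` (or `None`) together with a float. A rough call sketch using only the parameter names visible in this diff; exact keyword spelling and defaults may differ:

# `model` is a supported HF causal LM (GPT-NeoX or LLaMA here) and `train_ds`
# a tokenized dataset whose `z_column` holds the concept labels to erase.
scrubber, value = scrub(
    model,
    train_ds,
    z_column="label",
    batch_size=32,
    method="leace",
    sublayers=True,
)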
15 changes: 8 additions & 7 deletions concept_erasure/{ → scrubbing}/random_scrub.py
@@ -1,22 +1,23 @@
 from contextlib import contextmanager
 from functools import partial
+from typing import TYPE_CHECKING
 
 import torch
 from torch import Tensor, nn
-from transformers import PreTrainedModel
 
-from .utils import assert_type, is_norm_layer
+if (
+    TYPE_CHECKING
+):  # Don't import this unless we're type checking, since it's slow to import
+    from transformers import PreTrainedModel
+
+from ..utils import is_norm_layer
 
 
 @contextmanager
-def random_scrub(model: PreTrainedModel, subspace_dim: int):
+def random_scrub(model: "PreTrainedModel", subspace_dim: int):
     """Add hooks to the model which erase a random subspace during `forward`."""
     d = model.config.hidden_size
 
-    # Unwrap the base model if necessary
-    if isinstance(model, PreTrainedModel):
-        model = assert_type(PreTrainedModel, model.base_model)
-
     u = torch.empty(d, subspace_dim, device=model.device, dtype=model.dtype)
     nn.init.orthogonal_(u)
 
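The `TYPE_CHECKING` guard plus the quoted `"PreTrainedModel"` annotation keep transformers out of the runtime import path: type checkers still resolve the name, but nothing is imported when the module loads. A self-contained sketch of the pattern (generic function, not this file's code):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers such as pyright; never at runtime.
    from transformers import PreTrainedModel


def hidden_size(model: "PreTrainedModel") -> int:
    # The quoted annotation is a forward reference, so the name does not need
    # to exist at runtime; any Hugging Face model with a config works here.
    return model.config.hidden_size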
24 changes: 0 additions & 24 deletions concept_erasure/utils.py
@@ -1,8 +1,6 @@
-import math
 from typing import Any, Type, TypeVar, cast
 
 from torch import nn
-from transformers import PreTrainedModel
 
 T = TypeVar("T")
 
@@ -15,28 +13,6 @@ def assert_type(typ: Type[T], obj: Any) -> T:
     return cast(typ, obj)
 
 
-def chunk(seq: list[T], chunk_size: int) -> list[list[T]]:
-    """Chunk a sequence into chunks of size `chunk_size`."""
-
-    # Why the hell is this not in the standard library?
-    return [
-        seq[i * chunk_size : (i + 1) * chunk_size]
-        for i in range(math.ceil(len(seq) / chunk_size))
-    ]
-
-
-def get_transformer_layers(model: PreTrainedModel) -> nn.ModuleList:
-    """Return the `nn.ModuleList` containing the transformer layers in a model."""
-    assert not model.config.is_encoder_decoder, "Encoder-decoder models not supported"
-
-    lists = [mod for mod in model.modules() if isinstance(mod, nn.ModuleList)]
-    if not lists:
-        raise ValueError("Could not find transformer layers")
-
-    # Return the module list with the most parameters
-    return max(lists, key=lambda mod: sum(p.numel() for p in mod.parameters()))
-
-
 def is_norm_layer(module: nn.Module) -> bool:
     """Return `True` if the module is a normalization layer."""
     cls_name = type(module).__name__
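With `chunk` and `get_transformer_layers` deleted, nothing in utils touches transformers (or math) anymore. Downstream code that still needs the old `chunk` behaviour can use an equivalent one-liner; this replacement is a sketch, not part of the library:

def chunk(seq, chunk_size):
    """Split `seq` into consecutive chunks of at most `chunk_size` items."""
    return [seq[i : i + chunk_size] for i in range(0, len(seq), chunk_size)]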
6 changes: 3 additions & 3 deletions experiments/scrub.py
@@ -12,8 +12,7 @@
     PreTrainedTokenizerBase,
 )
 
-from concept_erasure import random_scrub
-from concept_erasure.scrubbing import patch_attention_, scrub
+from concept_erasure.scrubbing import patch_attention_, random_scrub, scrub
 from concept_erasure.utils import assert_type
 
 
@@ -95,12 +94,13 @@ def evaluate(ds, model, args, nats_to_bpb):
         desc="Evaluating",
         total=len(ds) // args.batch_size,
     )
+    base = assert_type(PreTrainedModel, model.base_model)
 
     for batch in pbar:
         assert isinstance(batch, dict)
 
         tokens = assert_type(torch.Tensor, batch["input_ids"])
-        with random_scrub(model, subspace_dim=k):
+        with random_scrub(base, subspace_dim=k):
             loss = model(tokens, labels=tokens).loss
         losses.append(loss)
 
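Because `random_scrub` no longer unwraps the wrapper model itself, the evaluation loop now unwraps once up front and passes the base model in. A minimal sketch of that call convention (the checkpoint name and subspace size are arbitrary choices for illustration):

from transformers import AutoModelForCausalLM, AutoTokenizer

from concept_erasure.scrubbing import random_scrub

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")
tok = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
tokens = tok("The quick brown fox", return_tensors="pt").input_ids

# Pass the unwrapped base model, as experiments/scrub.py now does.
with random_scrub(model.base_model, subspace_dim=8):
    loss = model(tokens, labels=tokens).loss
print(loss.item())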
10 changes: 6 additions & 4 deletions pyproject.toml
@@ -10,12 +10,9 @@ requires-python = ">=3.10"
 keywords = ["fairness", "interpretability", "explainable-ai"]
 license = {text = "MIT License"}
 dependencies = [
-    "datasets",
     "torch",
-    # 4.0 introduced the breaking change of using return_dict=True by default
-    "transformers>=4.0.0",
 ]
-version = "0.2.0"
+version = "0.2.1"
 
 [project.optional-dependencies]
 dev = [
@@ -24,6 +21,11 @@ dev = [
     "pytest",
     "pyright",
     "scikit-learn",
+
+    # Integration with HuggingFace datasets and transformers for concept scrubbing
+    "datasets",
+    # 4.0 introduced the breaking change of using return_dict=True by default
+    "transformers>=4.0.0",
 ]
 
 [project.scripts]
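With datasets and transformers moved under the dev extra, the base package now depends only on torch; a source checkout that still wants the scrubbing experiments would pull the extras in with something like `pip install -e ".[dev]"` (assuming pip and a local clone of the repository).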
