From 1547093f07b8fbdc2b61ae8d554691fcc32bef42 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Fri, 18 Aug 2023 02:27:13 +0000 Subject: [PATCH] Make datasets & transformers optional dependencies --- concept_erasure/__init__.py | 7 +- concept_erasure/concept_scrubber.py | 6 +- concept_erasure/data.py | 136 ------------------ concept_erasure/scrubbing/__init__.py | 3 +- concept_erasure/scrubbing/auto.py | 2 +- .../{ => scrubbing}/random_scrub.py | 15 +- concept_erasure/utils.py | 24 ---- experiments/scrub.py | 6 +- pyproject.toml | 10 +- 9 files changed, 25 insertions(+), 184 deletions(-) delete mode 100644 concept_erasure/data.py rename concept_erasure/{ => scrubbing}/random_scrub.py (72%) diff --git a/concept_erasure/__init__.py b/concept_erasure/__init__.py index 9334559..6b84acb 100644 --- a/concept_erasure/__init__.py +++ b/concept_erasure/__init__.py @@ -1,17 +1,12 @@ from .concept_scrubber import ConceptScrubber -from .data import chunk_and_tokenize from .leace import ErasureMethod, LeaceEraser, LeaceFitter from .oracle import OracleEraser, OracleFitter -from .random_scrub import random_scrub from .shrinkage import optimal_linear_shrinkage -from .utils import assert_type, chunk +from .utils import assert_type __all__ = [ "assert_type", - "chunk", - "chunk_and_tokenize", "optimal_linear_shrinkage", - "random_scrub", "ConceptScrubber", "LeaceEraser", "LeaceFitter", diff --git a/concept_erasure/concept_scrubber.py b/concept_erasure/concept_scrubber.py index 51f82f5..e22bb38 100644 --- a/concept_erasure/concept_scrubber.py +++ b/concept_erasure/concept_scrubber.py @@ -3,7 +3,6 @@ from typing import Callable from torch import Tensor, nn -from transformers import PreTrainedModel from .leace import LeaceEraser from .utils import assert_type, is_norm_layer, mangle_module_path @@ -46,7 +45,10 @@ def pre_wrapper(_, inputs, name: str) -> tuple[Tensor | None, ...]: key = mangle_module_path(name) return hook_fn(key, x), *extras - # Unwrap the base model if necessary + # Unwrap the base model if necessary. This is needed to ensure we don't try to + # scrub right before the unembedding layer + from transformers import PreTrainedModel + if isinstance(model, PreTrainedModel): model = assert_type(PreTrainedModel, model.base_model) diff --git a/concept_erasure/data.py b/concept_erasure/data.py deleted file mode 100644 index c6b19c4..0000000 --- a/concept_erasure/data.py +++ /dev/null @@ -1,136 +0,0 @@ -"""Tools for tokenizing and manipulating text datasets.""" -import math -from multiprocessing import cpu_count -from typing import TypeVar, Union - -from datasets import Dataset, DatasetDict -from transformers import PreTrainedTokenizerBase - -T = TypeVar("T", bound=Union[Dataset, DatasetDict]) - - -def chunk_and_tokenize( - data: T, - tokenizer: PreTrainedTokenizerBase, - *, - format: str = "torch", - num_proc: int = min(cpu_count() // 2, 8), - text_key: str = "text", - max_length: int = 2048, - return_final_batch: bool = False, - load_from_cache_file: bool = True, -) -> tuple[T, float]: - """Perform GPT-style chunking and tokenization on a dataset. - - The resulting dataset will consist entirely of chunks exactly `max_length` tokens - long. Long sequences will be split into multiple chunks, and short sequences will - be merged with their neighbors, using `eos_token` as a separator. The fist token - will also always be an `eos_token`. - - Args: - data: The dataset to chunk and tokenize. - tokenizer: The tokenizer to use. - format: The format to return the dataset in, passed to `Dataset.with_format`. 
- num_proc: The number of processes to use for tokenization. - text_key: The key in the dataset to use as the text to tokenize. - max_length: The maximum length of a batch of input ids. - return_final_batch: Whether to return the final batch, which may be smaller - than the others. - load_from_cache_file: Whether to load from the cache file. - - Returns: - * The chunked and tokenized dataset. - * The ratio of nats to bits per byte see https://arxiv.org/pdf/2101.00027.pdf, - section 3.1. - """ - - def _tokenize_fn(x: dict[str, list]): - chunk_size = min(tokenizer.model_max_length, max_length) - sep = tokenizer.eos_token or "<|endoftext|>" - joined_text = sep.join([""] + x[text_key]) - output = tokenizer( - # Concatenate all the samples together, separated by the EOS token. - joined_text, # start with an eos token - max_length=chunk_size, - return_attention_mask=False, - return_overflowing_tokens=True, - truncation=True, - ) - - if overflow := output.pop("overflowing_tokens", None): - # Slow Tokenizers return unnested lists of ints - assert isinstance(output["input_ids"][0], int) - - # Chunk the overflow into batches of size `chunk_size` - chunks = [output["input_ids"]] + [ - overflow[i * chunk_size : (i + 1) * chunk_size] - for i in range(math.ceil(len(overflow) / chunk_size)) - ] - output = {"input_ids": chunks} - - total_tokens = sum(len(ids) for ids in output["input_ids"]) - total_bytes = len(joined_text.encode("utf-8")) - - if not return_final_batch: - # We know that the last sample will almost always be less than the max - # number of tokens, and we don't want to pad, so we just drop it. - output = {k: v[:-1] for k, v in output.items()} - - output_batch_size = len(output["input_ids"]) - - if output_batch_size == 0: - raise ValueError( - "Not enough data to create a single batch complete batch." - " Either allow the final batch to be returned," - " or supply more data." - ) - - # We need to output this in order to compute the number of bits per byte - div, rem = divmod(total_tokens, output_batch_size) - output["length"] = [div] * output_batch_size - output["length"][-1] += rem - - div, rem = divmod(total_bytes, output_batch_size) - output["bytes"] = [div] * output_batch_size - output["bytes"][-1] += rem - - return output - - data = data.map( - _tokenize_fn, - # Batching is important for ensuring that we don't waste tokens - # since we always throw away the last element of the batch we - # want to keep the batch size as large as possible - batched=True, - batch_size=2048, - num_proc=num_proc, - remove_columns=get_columns_all_equal(data), - load_from_cache_file=load_from_cache_file, - ) - total_bytes: float = sum(data["bytes"]) - total_tokens: float = sum(data["length"]) - return data.with_format(format, columns=["input_ids"]), ( - total_tokens / total_bytes - ) / math.log(2) - - -def get_columns_all_equal(dataset: Union[Dataset, DatasetDict]) -> list[str]: - """Get a single list of columns in a `Dataset` or `DatasetDict`. - - We assert the columns are the same across splits if it's a `DatasetDict`. - - Args: - dataset: The dataset to get the columns from. - - Returns: - A list of columns. 
- """ - if isinstance(dataset, DatasetDict): - cols_by_split = dataset.column_names.values() - columns = next(iter(cols_by_split)) - if not all(cols == columns for cols in cols_by_split): - raise ValueError("All splits must have the same columns") - - return columns - - return dataset.column_names diff --git a/concept_erasure/scrubbing/__init__.py b/concept_erasure/scrubbing/__init__.py index 275acd8..803c7ba 100644 --- a/concept_erasure/scrubbing/__init__.py +++ b/concept_erasure/scrubbing/__init__.py @@ -1,3 +1,4 @@ from .auto import patch_attention_, scrub +from .random_scrub import random_scrub -__all__ = ["patch_attention_", "scrub"] +__all__ = ["patch_attention_", "random_scrub", "scrub"] diff --git a/concept_erasure/scrubbing/auto.py b/concept_erasure/scrubbing/auto.py index fd696da..c394557 100644 --- a/concept_erasure/scrubbing/auto.py +++ b/concept_erasure/scrubbing/auto.py @@ -37,7 +37,7 @@ def scrub( method: ErasureMethod = "leace", sublayers: bool = True, ) -> tuple[ConceptScrubber | None, float]: - """Scrub a model to remove the fast attention kernels.""" + """Apply concept scrubbing to `model` on dataset `train`, returning the scrubber.""" if isinstance(model, GPTNeoXForCausalLM): return scrub_neox(model, train, z_column, batch_size, method, affine) elif isinstance(model, LlamaForCausalLM): diff --git a/concept_erasure/random_scrub.py b/concept_erasure/scrubbing/random_scrub.py similarity index 72% rename from concept_erasure/random_scrub.py rename to concept_erasure/scrubbing/random_scrub.py index 3fa7050..7377b13 100644 --- a/concept_erasure/random_scrub.py +++ b/concept_erasure/scrubbing/random_scrub.py @@ -1,22 +1,23 @@ from contextlib import contextmanager from functools import partial +from typing import TYPE_CHECKING import torch from torch import Tensor, nn -from transformers import PreTrainedModel -from .utils import assert_type, is_norm_layer +if ( + TYPE_CHECKING +): # Don't import this unless we're type checking, since it's slow to import + from transformers import PreTrainedModel + +from ..utils import is_norm_layer @contextmanager -def random_scrub(model: PreTrainedModel, subspace_dim: int): +def random_scrub(model: "PreTrainedModel", subspace_dim: int): """Add hooks to the model which erase a random subspace during `forward`.""" d = model.config.hidden_size - # Unwrap the base model if necessary - if isinstance(model, PreTrainedModel): - model = assert_type(PreTrainedModel, model.base_model) - u = torch.empty(d, subspace_dim, device=model.device, dtype=model.dtype) nn.init.orthogonal_(u) diff --git a/concept_erasure/utils.py b/concept_erasure/utils.py index e106f38..f95d3c4 100644 --- a/concept_erasure/utils.py +++ b/concept_erasure/utils.py @@ -1,8 +1,6 @@ -import math from typing import Any, Type, TypeVar, cast from torch import nn -from transformers import PreTrainedModel T = TypeVar("T") @@ -15,28 +13,6 @@ def assert_type(typ: Type[T], obj: Any) -> T: return cast(typ, obj) -def chunk(seq: list[T], chunk_size: int) -> list[list[T]]: - """Chunk a sequence into chunks of size `chunk_size`.""" - - # Why the hell is this not in the standard library? 
- return [ - seq[i * chunk_size : (i + 1) * chunk_size] - for i in range(math.ceil(len(seq) / chunk_size)) - ] - - -def get_transformer_layers(model: PreTrainedModel) -> nn.ModuleList: - """Return the `nn.ModuleList` containing the transformer layers in a model.""" - assert not model.config.is_encoder_decoder, "Encoder-decoder models not supported" - - lists = [mod for mod in model.modules() if isinstance(mod, nn.ModuleList)] - if not lists: - raise ValueError("Could not find transformer layers") - - # Return the module list with the most parameters - return max(lists, key=lambda mod: sum(p.numel() for p in mod.parameters())) - - def is_norm_layer(module: nn.Module) -> bool: """Return `True` if the module is a normalization layer.""" cls_name = type(module).__name__ diff --git a/experiments/scrub.py b/experiments/scrub.py index d6ef4cf..b5ccaed 100644 --- a/experiments/scrub.py +++ b/experiments/scrub.py @@ -12,8 +12,7 @@ PreTrainedTokenizerBase, ) -from concept_erasure import random_scrub -from concept_erasure.scrubbing import patch_attention_, scrub +from concept_erasure.scrubbing import patch_attention_, random_scrub, scrub from concept_erasure.utils import assert_type @@ -95,12 +94,13 @@ def evaluate(ds, model, args, nats_to_bpb): desc="Evaluating", total=len(ds) // args.batch_size, ) + base = assert_type(PreTrainedModel, model.base_model) for batch in pbar: assert isinstance(batch, dict) tokens = assert_type(torch.Tensor, batch["input_ids"]) - with random_scrub(model, subspace_dim=k): + with random_scrub(base, subspace_dim=k): loss = model(tokens, labels=tokens).loss losses.append(loss) diff --git a/pyproject.toml b/pyproject.toml index 28832a8..b0ac95f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,12 +10,9 @@ requires-python = ">=3.10" keywords = ["fairness", "interpretability", "explainable-ai"] license = {text = "MIT License"} dependencies = [ - "datasets", "torch", - # 4.0 introduced the breaking change of using return_dict=True by default - "transformers>=4.0.0", ] -version = "0.2.0" +version = "0.2.1" [project.optional-dependencies] dev = [ @@ -24,6 +21,11 @@ dev = [ "pytest", "pyright", "scikit-learn", + + # Integration with HuggingFace datasets and transformers for concept scrubbing + "datasets", + # 4.0 introduced the breaking change of using return_dict=True by default + "transformers>=4.0.0", ] [project.scripts]
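
Usage note (not part of the diff): a minimal sketch of the user-facing effect of this change, assuming the package is installed from this branch. The checkpoint name and `subspace_dim` below are illustrative, not taken from the patch; the extra group name follows the `dev` section added to pyproject.toml above.

    # Core erasure API: after this patch it imports with only `torch` installed,
    # since the eager `transformers`/`datasets` imports were removed.
    from concept_erasure import LeaceEraser, LeaceFitter, OracleEraser

    # Scrubbing helpers now live under `concept_erasure.scrubbing` and expect the
    # HuggingFace packages to be present, e.g. `pip install datasets "transformers>=4.0.0"`
    # (bundled into the `dev` extra by this patch).
    import torch
    from transformers import AutoModelForCausalLM

    from concept_erasure.scrubbing import random_scrub

    model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")  # illustrative checkpoint
    base = model.base_model  # callers now unwrap the base model themselves, as in experiments/scrub.py

    with random_scrub(base, subspace_dim=16):
        # Hooks are active inside this block: forward passes have a random
        # 16-dimensional subspace erased from the hidden states.
        tokens = torch.randint(0, model.config.vocab_size, (1, 32))
        loss = model(tokens, labels=tokens).loss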