diff --git a/README.md b/README.md index efb1144..03f4936 100644 --- a/README.md +++ b/README.md @@ -60,12 +60,19 @@ approaching 1 **The perplexity and BLEU4 scores are those reported in the original paper and not measured via this codebase. +And unconditional perturber models: + +| | Base model | Parameters | Perplexity | Perplexity (perturbed idx) | Perplexity (word) | Perplexity (attribute) | +|-----------------------------------------------------------------------------------------------|--------------------------------------------------------------|------------|------------|----------------------------|-------------------|------------------------| +| [unconditional-perturber-small](https://huggingface.co/fairnlp/unconditional-perturber-small) | [bart-small](https://huggingface.co/lucadiliello/bart-small) | 70m | 1.101 | 4.265 | 5.262 | 5.635 | +| [unconditional-perturber-base](https://huggingface.co/fairnlp/unconditional-perturber-base) | [bart-base](https://huggingface.co/facebook/bart-base) | 139m | 1.082 | 2.897 | 4.537 | 5.328 | + # Roadmap - [x] Add default perturber model - [x] Pretrain small and medium perturber models -- [ ] Train model to identify target words and attributes -- [ ] Add training of unconditional perturber models (i.e. only get a sentence, no target word/attribute) +- [x] Add training of unconditional perturber models (i.e. only get a sentence, no target word/attribute) +- [ ] Add - [ ] Add self-training by pretraining perturber base model (e.g. BART) on self-perturbed data Other features could include diff --git a/perturbers/data/panda_dict.py b/perturbers/data/panda_dict.py index 79c469c..19e3bf4 100644 --- a/perturbers/data/panda_dict.py +++ b/perturbers/data/panda_dict.py @@ -1,7 +1,9 @@ +from typing import Optional, List + from datasets import concatenate_datasets, load_dataset GENDER_ATTRIBUTES = {"man", "woman", "non-binary"} -RACE_ATTRIBUTES = {"black", "white", "hispanic", "native-american", "pacific-islander"} +RACE_ATTRIBUTES = {"black", "white", "asian", "hispanic", "native-american", "pacific-islander"} AGE_ATTRIBUTES = {"child", "young", "middle-aged", "senior", "adult"} ALL_ATTRIBUTES = GENDER_ATTRIBUTES | RACE_ATTRIBUTES | AGE_ATTRIBUTES @@ -21,3 +23,32 @@ def get_panda_dict() -> dict[str, list[str]]: perturbation_dict[word] = perturbation_dict.get(word, []) + [attribute] sorted_dict = dict(reversed(sorted(perturbation_dict.items(), key=lambda item: len(item[1])))) return sorted_dict + + +def get_attribute_tokens(attribute_set: Optional[set[str]] = None) -> List[str]: + """ + Creates specific attribute tokens + + Args: + attribute_set: Set of attributes to be used for token generation. If None, all attributes are used. + + Returns: + A list of attribute tokens. 
+ """ + if attribute_set is None: + attribute_set = ALL_ATTRIBUTES + + return [attribute_to_token(attr) for attr in attribute_set] + + +def attribute_to_token(attribute: str) -> str: + """ + Converts an attribute to a token + + Args: + attribute: The attribute to be converted to a token + + Returns: + The token corresponding to the attribute + """ + return f"<{attribute.upper().replace('-', '_').replace(' ', '_')}>" diff --git a/perturbers/modeling/perturber.py b/perturbers/modeling/perturber.py index 5b07bd4..321c021 100644 --- a/perturbers/modeling/perturber.py +++ b/perturbers/modeling/perturber.py @@ -1,10 +1,14 @@ +import logging import random from dataclasses import dataclass -from typing import Optional, Union +from typing import Optional, Union, Literal, Tuple, List +import numpy as np +import torch from transformers import BartForConditionalGeneration, AutoTokenizer, PreTrainedModel +from transformers.generation.configuration_utils import GenerationConfig -from perturbers.data.panda_dict import get_panda_dict, ALL_ATTRIBUTES +from perturbers.data.panda_dict import get_panda_dict, attribute_to_token, ALL_ATTRIBUTES @dataclass @@ -12,6 +16,12 @@ class PerturberConfig: sep_token: str = '' pert_sep_token: str = '' max_length: int = 128 + conditional: bool = True + + +@dataclass +class UnconditionalPerturberConfig(PerturberConfig): + conditional: bool = False class Perturber: @@ -45,16 +55,84 @@ def __init__( self.model = BartForConditionalGeneration.from_pretrained(model_name) self.config.sep_token = "," self.config.pert_sep_token = "" + if config is None and "unconditional" in model_name: + logging.info("Inferring unconditional perturber from model name") + self.config.conditional = False + + self.attribute_to_token = {a: attribute_to_token(a) for a in ALL_ATTRIBUTES} + self.token_to_attribute = {t: a for a, t in self.attribute_to_token.items()} self.model.config.max_length = self.config.max_length self.tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True) self.tokenizer.add_tokens([self.config.pert_sep_token], special_tokens=True) + if self.model.model.get_input_embeddings().num_embeddings != len(self.tokenizer): + logging.warning("Number of tokens in tokenizer does not match number of tokens in model. The model " + "embeddings will be resized by adding random weights.") + self.model.model.resize_token_embeddings(len(self.tokenizer)) self.panda_dict = get_panda_dict() self.input_template = PerturberTemplate(sep=self.config.sep_token, pert_sep=self.config.pert_sep_token, - original=model_name == "facebook/perturber") + original=model_name == "facebook/perturber", + conditional=self.config.conditional) - def generate(self, input_txt: str, word: str = "", attribute: str = "", tokenizer_kwargs=None) -> str: + def get_attribute_probabilities(self, input_txt: str, softmax: bool = True, + attributes: Optional[dict[str, str]] = None) -> dict[str, float]: + # TODO add option to normalize for prior, e.g. p(man) < p(woman) due to PANDA imbalance + """ + Given a certain input string, get a conditional probability distribution over a specified set of attributes. + + Args: + input_txt: string to base the probability distribution on. + + softmax: if True, apply a softmax over the logits, else return the likelihood values. + + attributes: mapping of attributes to tokens for which to get probabilities. Default is all attributes. 
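+
+        Returns:
+            A dictionary mapping each attribute to its probability (or to its likelihood value if softmax is False).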
+ """ + if attributes is None: + attributes = self.attribute_to_token + + if self.config.conditional: + raise RuntimeError("Attribute classification is not possible for conditional perturber models") + decoder_inputs = self.tokenizer(self.tokenizer.bos_token, return_tensors='pt', add_special_tokens=False) + logits = self.model( + **self.tokenizer(input_txt, return_tensors='pt'), + decoder_input_ids=decoder_inputs.data["input_ids"], + decoder_attention_mask=decoder_inputs.data["attention_mask"], + ).logits[0, -1] + assert len(logits.shape) == 1 and logits.shape[0] == len(self.tokenizer.vocab) + + if softmax: + token_logits = torch.tensor([logits[self.tokenizer.vocab[t]] for t in attributes.values()]) + softmax_probs = torch.softmax(token_logits, dim=0) + return {a: float(prob) for a, prob in zip(attributes.keys(), softmax_probs)} + else: + return {a: float(logits[self.tokenizer.vocab[t]].exp()) for a, t in attributes.items()} + + def generate_conditions(self, input_txt: str, n_permutations: int, tokenizer_kwargs: dict) -> List[tuple[str, str]]: + if self.config.conditional: + raise RuntimeError("Attribute classification is not possible for conditional perturber models") + gen_config = GenerationConfig.from_model_config(self.model.config) + gen_config.update(eos_token_id=self.tokenizer.vocab[self.config.pert_sep_token]) + generation = self.model.generate( + **self.tokenizer(input_txt, return_tensors='pt', **tokenizer_kwargs), + generation_config=gen_config, + max_new_tokens=self.model.config.max_length, + num_return_sequences=n_permutations, + num_beams=n_permutations, + ) + attribute_tokens = generation[:, 2] + target_tokens = generation[:, 3:] + + # Hack to prevent double brackets from InputTemplate + attributes = [self.token_to_attribute.get(a) for a in self.tokenizer.batch_decode(attribute_tokens, + max_new_tokens=self.model.config.max_length)] + target_words = [w.lstrip() for w in self.tokenizer.batch_decode(target_tokens, skip_special_tokens=True, + max_new_tokens=self.model.config.max_length)] + # Filter attribute hallucinations (unlikely) + return [(w, a) for idx, (w, a) in enumerate(zip(target_words, attributes)) if a is not None] + + def generate(self, input_txt: str, word: str = "", attribute: str = "", tokenizer_kwargs: Optional[dict] = None, + generate_kwargs: Optional[dict] = None) -> Tuple[str, float]: """ Generates a perturbed version of the input text. 
@@ -67,24 +145,59 @@ def generate(self, input_txt: str, word: str = "", attribute: str = "", tokenize tokenizer_kwargs: Additional keyword arguments to be passed to the tokenizer + generate_kwargs: Additional keyword arguments to be passed to the generate method + Returns: - Perturbed version of the input text + Perturbed version of the input text along with the average token probability """ if tokenizer_kwargs is None: tokenizer_kwargs = {} - if attribute and attribute not in ALL_ATTRIBUTES: - raise ValueError(f"Attribute {attribute} not in {ALL_ATTRIBUTES}") - input_txt = self.input_template(input_txt, word, attribute) - output_tokens = self.model.generate(**self.tokenizer(input_txt, return_tensors='pt'), **tokenizer_kwargs) - return self.tokenizer.batch_decode( - output_tokens, - skip_special_tokens=True, - max_new_tokens=self.model.config.max_length - )[0].lstrip() - - def __call__(self, input_txt, mode='word_list', tokenizer_kwargs=None, retry_unchanged=False - ) -> Union[str, NotImplementedError]: + if generate_kwargs is None: + generate_kwargs = {} + generate_kwargs["return_dict_in_generate"] = True + generate_kwargs["output_scores"] = True + + # Validate the attribute -- generated attribute is validated after generation + if self.config.conditional and attribute and attribute not in self.attribute_to_token: + raise ValueError(f"Attribute {attribute} not in {self.attribute_to_token.keys()}") + + if self.config.conditional: + input_txt = self.input_template(input_txt, word, attribute) + tokens = self.tokenizer(input_txt, return_tensors='pt', **tokenizer_kwargs) + outputs = self.model.generate(**tokens, **generate_kwargs) + else: + prefix = self.tokenizer.bos_token + self.input_template.get_sentence_prefix(word, attribute) + encoder_tokens = self.tokenizer(input_txt, return_tensors='pt', **tokenizer_kwargs) + decoder_tokens = self.tokenizer(prefix, return_tensors='pt', add_special_tokens=False, **tokenizer_kwargs) + outputs = self.model.generate( + input_ids=encoder_tokens.data["input_ids"], + attention_mask=encoder_tokens.data["attention_mask"], + decoder_input_ids=decoder_tokens.data["input_ids"], + decoder_attention_mask=decoder_tokens.data["attention_mask"], + **generate_kwargs + ) + output_string = self._decode_generation(outputs) + probabilities = self.model.compute_transition_scores(outputs.sequences, outputs.scores).exp() + return output_string, float(probabilities.mean()) + + def _decode_generation(self, outputs): + if self.config.conditional: + decode_tokens = outputs.sequences + else: + output_string = self.tokenizer.decode( + outputs.sequences[0], skip_special_tokens=False, max_new_tokens=self.model.config.max_length + ) + output_string = self.config.pert_sep_token.join(output_string.split(self.config.pert_sep_token)[1:]) + decode_tokens = self.tokenizer(output_string, return_tensors='pt').input_ids + output_string = self.tokenizer.decode( + decode_tokens[0], skip_special_tokens=True, max_new_tokens=self.model.config.max_length + ) + return output_string.lstrip() # Remove trailing space from tokenization + + def __call__(self, input_txt: str, mode: Optional[Literal['word_list', 'highest_prob', 'classify']] = None, + tokenizer_kwargs: Optional[dict] = None, generate_kwargs: Optional[dict] = None, + n_perturbations: int = 1, early_stopping: bool = False) -> Union[str, ValueError]: """ Perturbs the input text using the specified mode and returns the perturbed text. No target word or attribute needs to be specified for this method. 
@@ -92,34 +205,55 @@ def __call__(self, input_txt, mode='word_list', tokenizer_kwargs=None, retry_unc Args: input_txt: The input text to be perturbed - mode: The mode to be used for perturbation. Currently, only 'word_list' is supported + mode: The mode to be used for perturbation tokenizer_kwargs: Additional keyword arguments to be passed to the tokenizer - retry_unchanged: If True, perturbation is retried with different target words/attributes until the output is - different from the input + generate_kwargs: Additional keyword arguments to be passed to the generate method + + n_perturbations: The number of perturbations to be generated + + early_stopping: If True, the first perturbation that differs from the input is returned. Returns: Perturbed version of the input text """ + # Input validation if tokenizer_kwargs is None: tokenizer_kwargs = {} - - if mode == 'highest_prob': - raise NotImplementedError # TODO - elif mode == 'word_list': + if generate_kwargs is None: + generate_kwargs = {} + if n_perturbations < 1: + raise ValueError(f"At least one perturbation needs to be tried. Received {n_perturbations} as argument.") + if mode is not None and mode not in {'word_list', 'classify'}: + raise ValueError(f"Mode {mode} is invalid. Please choose from 'word_list' or 'classify'.") + if mode is None: + mode = 'word_list' if self.config.conditional else 'classify' + elif mode == 'classify' and self.config.conditional: + raise ValueError("Conditional perturber models are not trained for attribute classification") + + # Get perturbation targets and words + if mode == 'word_list': targets = [w for w in input_txt.split(" ") if w in self.panda_dict] perturbations = [(t, perturbed) for t in targets for perturbed in self.panda_dict[t]] - random.shuffle(perturbations) - for word, attribute in perturbations: - generated_txt = self.generate(input_txt, word=word, attribute=attribute, - tokenizer_kwargs=tokenizer_kwargs) - if generated_txt != input_txt or not retry_unchanged: + elif mode == 'classify': + perturbations = self.generate_conditions(input_txt, n_perturbations, tokenizer_kwargs) + random.shuffle(perturbations) + perturbations = perturbations[:n_perturbations] + texts, probabilities = [], [] + + for word, attribute in perturbations: + generated_txt, probability = self.generate(input_txt, word=word, attribute=attribute, + tokenizer_kwargs=tokenizer_kwargs, + generate_kwargs=generate_kwargs) + if generated_txt != input_txt: + if early_stopping: return generated_txt - else: - raise NotImplementedError + else: + texts.append(generated_txt) + probabilities.append(probability) - return input_txt + return texts[np.argmax(probabilities)] if texts else input_txt class PerturberTemplate: @@ -130,9 +264,32 @@ class PerturberTemplate: their occurrences in the input text. 
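+
+    When conditional is False, the attribute token and target word are not prepended to the source text; they
+    instead form a decoder prefix for the perturbed text (see `get_sentence_prefix`).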
""" - def __init__(self, sep: str = ",", pert_sep: str = "", original: bool = False) -> None: + def __init__(self, sep: str = ",", pert_sep: str = "", + original: bool = False, conditional: bool = True) -> None: self.sep = sep + self.conditional = conditional self.pert_sep = pert_sep if not original else f" {pert_sep}" def __call__(self, input_txt: str, word: str = "", attribute: str = "") -> str: - return f"{word}{self.sep} {attribute}{self.pert_sep} {input_txt}" + if not self.conditional: + return f"{attribute_to_token(attribute)} {word}{self.pert_sep} {input_txt}" + else: + return f"{word}{self.sep} {attribute}{self.pert_sep} {input_txt}" + + def get_sentence_prefix(self, word: str = "", attribute: str = "") -> Union[str, ValueError]: + if not self.conditional: + return f"{attribute_to_token(attribute)} {word}{self.pert_sep}" + else: + return ValueError("Sentence prefix not available for conditional perturber") + + def get_attribute_prefix(self, word: str = "", attribute: str = "") -> Union[str, ValueError]: + if not self.conditional: + return "" + else: + return ValueError("Attribute prefix not available for conditional perturber") + + def get_word_prefix(self, word: str = "", attribute: str = "") -> Union[str, ValueError]: + if not self.conditional: + return f"{attribute_to_token(attribute)}" + else: + return ValueError("Word prefix not available for conditional perturber") diff --git a/perturbers/training/core.py b/perturbers/training/core.py index 13147e9..b8b724d 100644 --- a/perturbers/training/core.py +++ b/perturbers/training/core.py @@ -11,8 +11,10 @@ from torchmetrics.text import Perplexity, BLEUScore from transformers import AutoModel, BartForConditionalGeneration # noqa 401 from transformers import AutoTokenizer, get_linear_schedule_with_warmup +from transformers import DataCollatorWithPadding from transformers import PreTrainedTokenizerBase, PreTrainedModel +from perturbers.data.panda_dict import get_attribute_tokens from perturbers.modeling.perturber import PerturberTemplate from perturbers.training.utils import TrainingConfig, get_diff_indices @@ -37,15 +39,18 @@ def __init__(self, c: TrainingConfig, tokenizer: PreTrainedTokenizerBase) -> Non self.test_batch_size = c.test_batch_size self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id) - self.train_metrics = self.get_metric_dict("train") - self.val_metrics = self.get_metric_dict("val") - self.test_metrics = self.get_metric_dict("test") + self.train_metrics = self.get_metric_dict(c, "train") + self.val_metrics = self.get_metric_dict(c, "val") + self.test_metrics = self.get_metric_dict(c, "test") - def get_metric_dict(self, split: str) -> dict[str, torch.nn.Module]: + def get_metric_dict(self, c: TrainingConfig, split: str) -> dict[str, torch.nn.Module]: metrics = { f'{split}_ppl': Perplexity(ignore_index=self.tokenizer.pad_token_id).to(self._device), f'{split}_ppl_perturbed': Perplexity(ignore_index=self.tokenizer.pad_token_id).to(self._device), } + if not c.conditional: + metrics[f'{split}_ppl_word'] = Perplexity(ignore_index=self.tokenizer.pad_token_id).to(self._device) + metrics[f'{split}_ppl_attribute'] = Perplexity(ignore_index=self.tokenizer.pad_token_id).to(self._device) if split == "test": metrics[f'{split}_bleu4'] = BLEUScore(n_gram=4).to(self._device) return metrics @@ -64,8 +69,14 @@ def update_metrics( target=[[_] for _ in self.tokenizer.batch_decode(batch['labels'], skip_special_tokens=True)], ) elif "ppl" in metric_key: - if "perturbed" in metric_key: + idx = None + if 
metric_key.endswith("perturbed"): idx = batch["perturbed_idx"] + elif metric_key.endswith("word"): + idx = batch["word_idx"] + elif metric_key.endswith("attribute"): + idx = batch["attribute_idx"] + if idx is not None: value = metric(preds=outputs[idx].unsqueeze(0), target=batch['labels'][idx].unsqueeze(0)) else: value = metric(preds=outputs, target=batch['labels']) @@ -112,10 +123,7 @@ def forward(self, batch: dict) -> Tuple[torch.Tensor, torch.Tensor]: return outputs.logits, outputs.loss def generate(self, batch: dict) -> List[str]: - generations = self.model.generate( - **{k: v for k, v in batch.items() if k in ["input_ids", "attention_mask"]}, - max_length=batch['input_ids'].shape[-1], - ) + generations = self.model.generate(**{k: v for k, v in batch.items() if k in ["input_ids", "attention_mask"]}) return self.tokenizer.batch_decode(generations, skip_special_tokens=True) def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[dict]]: @@ -152,29 +160,46 @@ def get_collate_fn(c: TrainingConfig, tokenizer: PreTrainedTokenizerBase, tokeni Returns: The collate function for the dataloaders """ - input_template = PerturberTemplate(sep=c.sep_token, pert_sep=c.pert_sep_token, - original=c.model_name == "facebook/perturber") + collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors='pt', padding=True) def collate_fn(batch: List) -> dict: original, perturbed = [], [] perturbed_x, perturbed_y = [], [] + attribute_x, attribute_y = [], [] + word_x, word_y = [], [] for i, item in enumerate(batch): - perturbed.append(item['perturbed']) - original.append(input_template(item["original"], item["selected_word"], item["target_attribute"])) - idx = item["perturbed_idx"] - perturbed_x += [i] * len(idx) - perturbed_y += idx + original.append(item["original"]) + perturbed.append(item["perturbed"]) + + perturbed_idx = item["perturbed_idx"] + perturbed_x += [i] * len(perturbed_idx) + perturbed_y += perturbed_idx - original = tokenizer(original, return_tensors='pt', **tokenizer_kwargs) - perturbed = tokenizer(perturbed, return_tensors='pt', **tokenizer_kwargs) + if not c.conditional: + word_idx = item["word_idx"] + word_x += [i] * len(word_idx) + word_y += word_idx - return { + attribute_idx = item["attribute_idx"] + attribute_x += [i] * len(attribute_idx) + attribute_y += attribute_idx + + original = collator(original) + perturbed = collator(perturbed) + + return_dict = { "input_ids": original["input_ids"], "attention_mask": original["attention_mask"], "labels": perturbed["input_ids"], "perturbed_idx": (perturbed_x, perturbed_y), } + if not c.conditional: + return_dict["word_idx"] = (word_x, word_y) + return_dict["attribute_idx"] = (attribute_x, attribute_y) + + return return_dict + return collate_fn @@ -205,15 +230,46 @@ def get_callbacks(c: TrainingConfig) -> List[pl.callbacks.Callback]: ] -def add_indices(sample: dict, tokenizer: PreTrainedTokenizerBase, tokenizer_kwargs: dict) -> dict: +def preprocess_inputs(sample: dict, tokenizer: PreTrainedTokenizerBase, tokenizer_kwargs: dict, c: TrainingConfig, + input_template: PerturberTemplate, tokenize: bool = True) -> dict: """ Add the indices of the tokens that are different between the original and perturbed text to the sample dictionary. Function signature is intended to be used with the `map` method of the Hugging Face datasets library. 
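+
+    For unconditional models, the perturbed indices are shifted to account for the decoder prefix, and the token
+    indices of the selected word and the target attribute are added as well.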
""" - sample["perturbed_idx"] = get_diff_indices( + + idx = get_diff_indices( tokenizer(sample['original'], **tokenizer_kwargs).data['input_ids'], - tokenizer(sample['perturbed'], **tokenizer_kwargs).data['input_ids'] + tokenizer(sample['perturbed'], **tokenizer_kwargs).data['input_ids'], ) + + if c.conditional: + sample["perturbed_idx"] = idx + sample['original'] = input_template(sample["original"], sample["selected_word"], sample["target_attribute"]) + else: + + # Account for prefix in perturbed indices + sentence_prefix = input_template.get_sentence_prefix(sample["selected_word"], sample["target_attribute"]) + sentence_offset = len(tokenizer.tokenize(sentence_prefix)) + sample["perturbed_idx"] = [i + sentence_offset for i in idx] + sample["perturbed"] = input_template(sample["perturbed"], sample["selected_word"], sample["target_attribute"]) + + # Add indices for word and attribute + n_cls_tokens = len(tokenizer.tokenize(tokenizer.bos_token)) + word_prefix = input_template.get_word_prefix(sample["selected_word"], sample["target_attribute"]) + word_offset = len(tokenizer.tokenize(word_prefix)) + n_cls_tokens + sample["word_idx"] = [i + word_offset for i in range(len(tokenizer.tokenize(sample["selected_word"])))] + + attribute_prefix = input_template.get_attribute_prefix(sample["selected_word"], sample["target_attribute"]) + attribute_offset = len(tokenizer.tokenize(attribute_prefix)) + n_cls_tokens + sample["attribute_idx"] = [attribute_offset] # Only one attribute token + + for idx_key in ["perturbed_idx", "word_idx", "attribute_idx"]: + sample[idx_key] = [i for i in sample[idx_key] if i < c.max_length] + + if tokenize: + sample['original'] = tokenizer(sample['original'], **tokenizer_kwargs) + sample['perturbed'] = tokenizer(sample['perturbed'], **tokenizer_kwargs) + return sample @@ -227,13 +283,15 @@ def train_perturber(c: TrainingConfig) -> PreTrainedModel: c.train_steps = 10 c.val_steps = 5 c.accumulate_grad_batches = 1 + c.num_workers = 0 - tokenizer = AutoTokenizer.from_pretrained(c.model_name, add_prefix_space=True) - tokenizer.add_tokens([c.sep_token, c.pert_sep_token], special_tokens=True) - tokenizer_kwargs = {"padding": True, "truncation": True, "max_length": c.max_length} + tokenizer, tokenizer_kwargs = get_tokenizer(c) model = LightningWrapper(c, tokenizer) dataset = load_dataset(c.dataset_name) + input_template = PerturberTemplate(sep=c.sep_token, pert_sep=c.pert_sep_token, + original=c.model_name == "facebook/perturber", conditional=c.conditional) + train_ds = dataset["train"] val_ds = dataset["validation"] @@ -241,8 +299,9 @@ def train_perturber(c: TrainingConfig) -> PreTrainedModel: train_ds = train_ds.select(range(128)) val_ds = val_ds.select(range(128)) - train_ds = train_ds.map(lambda x: add_indices(x, tokenizer, tokenizer_kwargs), num_proc=max(c.num_workers, 1)) - val_ds = val_ds.map(lambda x: add_indices(x, tokenizer, tokenizer_kwargs), num_proc=max(c.num_workers, 1)) + map_fn = lambda x: preprocess_inputs(x, tokenizer, tokenizer_kwargs, c, input_template) + train_ds = train_ds.map(map_fn, num_proc=max(c.num_workers, 1)) + val_ds = val_ds.map(map_fn, num_proc=max(c.num_workers, 1)) collate_fn = get_collate_fn(c, tokenizer, tokenizer_kwargs) @@ -285,3 +344,13 @@ def train_perturber(c: TrainingConfig) -> PreTrainedModel: tokenizer.push_to_hub(c.hub_repo_id) return model.model + + +def get_tokenizer(c): + tokenizer = AutoTokenizer.from_pretrained(c.model_name, add_prefix_space=True) + new_tokens = [c.sep_token, c.pert_sep_token] + if not c.conditional: + new_tokens += 
get_attribute_tokens() + tokenizer.add_tokens(new_tokens, special_tokens=True) + tokenizer_kwargs = {"padding": True, "truncation": True, "max_length": c.max_length} + return tokenizer, tokenizer_kwargs diff --git a/perturbers/training/utils.py b/perturbers/training/utils.py index 18f161e..96489f6 100644 --- a/perturbers/training/utils.py +++ b/perturbers/training/utils.py @@ -6,6 +6,7 @@ @dataclass class TrainingConfig: + conditional: bool = True model_name: str = "facebook/bart-large" dataset_name: str = "facebook/panda" train_batch_size: int = 64 diff --git a/setup.py b/setup.py index fb08e26..91e24c0 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,7 @@ tests_require=TEST_REQUIREMENTS, extras_require=EXTRAS_REQUIREMENTS, url="https://github.com/FairNLP/perturbers", - version="0.0.0", + version="0.0.1", long_description=open('README.md').read(), long_description_content_type='text/markdown', zip_safe=False, diff --git a/tests/test_generate.py b/tests/test_generate.py new file mode 100644 index 0000000..bf5bee4 --- /dev/null +++ b/tests/test_generate.py @@ -0,0 +1,23 @@ +from perturbers import Perturber +from perturbers.modeling.perturber import PerturberConfig + + +def test_word_list(): + model = Perturber("hf-internal-testing/tiny-random-bart") + model( + mode="word_list", + input_txt="a", + generate_kwargs={"max_new_tokens": 2} + ) + + +def test_classify(): + model = Perturber( + model="hf-internal-testing/tiny-random-bart", + config=PerturberConfig(conditional=False, max_length=8), + ) + model( + mode="classify", + input_txt="a", + generate_kwargs={"max_new_tokens": 2} + ) diff --git a/tests/test_model.py b/tests/test_model.py index f22540e..698a33c 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -9,19 +9,19 @@ def test_perturber_model(): model = Perturber() - perturbed = model.generate(UNPERTURBED, "Jack", "woman") + perturbed, probability = model.generate(UNPERTURBED, "Jack", "woman") assert perturbed == PERTURBED_ORIGINAL def test_small_perturber_model(): model = Perturber("fairnlp/perturber-small") - perturbed = model.generate(UNPERTURBED, "Jack", "woman") + perturbed, probability = model.generate(UNPERTURBED, "Jack", "woman") assert perturbed == PERTURBED_SMALL def test_base_perturber_model(): model = Perturber("fairnlp/perturber-base") - perturbed = model.generate(UNPERTURBED, "Jack", "woman") + perturbed, probability = model.generate(UNPERTURBED, "Jack", "woman") assert perturbed == PERTURBED_BASE diff --git a/tests/test_prepare_inputs.py b/tests/test_prepare_inputs.py new file mode 100644 index 0000000..690fb78 --- /dev/null +++ b/tests/test_prepare_inputs.py @@ -0,0 +1,81 @@ +from typing import List + +from perturbers.data.panda_dict import attribute_to_token +from perturbers.modeling.perturber import PerturberTemplate +from perturbers.training.core import preprocess_inputs, get_tokenizer +from perturbers.training.utils import TrainingConfig + +# mock data +sample = { + "original": "Perturbers are cool!", + "perturbed": "Perturbers are great!", + "selected_word": "cool", + "target_attribute": "non-binary", +} +perturbed_span = " great" +selected_word_span = " cool" +target_attribute_span = attribute_to_token("non-binary") + + +def get_preprocessed_sample(c: TrainingConfig) -> dict: + tokenizer, tokenizer_kwargs = get_tokenizer(c) + input_template = PerturberTemplate(sep=c.sep_token, pert_sep=c.pert_sep_token, + original=c.model_name == "facebook/perturber", conditional=c.conditional) + preprocessed = preprocess_inputs( + sample=sample, + tokenizer=tokenizer, + 
tokenizer_kwargs=tokenizer_kwargs, + c=c, + input_template=input_template, + tokenize=False, + ) + return preprocessed + + +def get_span_at_idx(sequence: str, token_idx: List[int], c: TrainingConfig) -> str: + tokenizer, tokenizer_kwargs = get_tokenizer(c) + tokens = tokenizer(sequence, **tokenizer_kwargs) + return tokenizer.decode([tokens['input_ids'][i] for i in token_idx]) + + +def test_prepare_inputs_conditional(): + c = TrainingConfig( + model_name="hf-internal-testing/tiny-random-bart", + debug=True, + max_length=64, + conditional=True, + ) + preprocessed = get_preprocessed_sample(c) + span = get_span_at_idx( + sequence=preprocessed["perturbed"], + token_idx=preprocessed["perturbed_idx"], + c=c, + ) + assert span == perturbed_span + + +def test_prepare_inputs_unconditional(): + c = TrainingConfig( + model_name="hf-internal-testing/tiny-random-bart", + debug=True, + max_length=64, + conditional=False, + ) + preprocessed = get_preprocessed_sample(c) + assert get_span_at_idx( + sequence=preprocessed["perturbed"], + token_idx=preprocessed["perturbed_idx"], + c=c, + ) == perturbed_span + + assert get_span_at_idx( + sequence=preprocessed["perturbed"], + token_idx=preprocessed["word_idx"], + c=c, + ) == selected_word_span + + assert get_span_at_idx( + sequence=preprocessed["perturbed"], + token_idx=preprocessed["attribute_idx"], + c=c, + ) == target_attribute_span diff --git a/tests/test_unconditional_model.py b/tests/test_unconditional_model.py new file mode 100644 index 0000000..3e07f85 --- /dev/null +++ b/tests/test_unconditional_model.py @@ -0,0 +1,34 @@ +from perturbers import Perturber +from perturbers.data.panda_dict import GENDER_ATTRIBUTES, ALL_ATTRIBUTES, attribute_to_token + +UNPERTURBED = "Jack was passionate about rock climbing and his love for the sport was infectious to all men around him." +PERTURBED_SMALL = "Mary was passionate about rock climbing and her love for the sport was infectious to all men around her." +PERTURBED_BASE = "Jacqueline was passionate about rock climbing and her love for the sport was infectious to all men around her." 
+
+
+def test_small_perturber_model():
+    model = Perturber("fairnlp/unconditional-perturber-small")
+
+    perturbed, probability = model.generate(UNPERTURBED, "Jack", "woman")
+    assert perturbed == PERTURBED_SMALL
+
+
+def test_base_perturber_model():
+    model = Perturber("fairnlp/unconditional-perturber-base")
+
+    perturbed, probability = model.generate(UNPERTURBED, "Jack", "woman")
+    assert perturbed == PERTURBED_BASE
+
+
+def test_get_attribute_probabilities():
+    model = Perturber("fairnlp/unconditional-perturber-small")
+    gender_token_map = {a: attribute_to_token(a) for a in GENDER_ATTRIBUTES}
+    male_probabilities = model.get_attribute_probabilities("He", attributes=gender_token_map)
+    female_probabilities = model.get_attribute_probabilities("She", attributes=gender_token_map)
+    assert male_probabilities["woman"] > female_probabilities["woman"]
+    assert female_probabilities["man"] > male_probabilities["man"]
+    assert len(male_probabilities) == len(gender_token_map)
+
+    all_probabilities = model.get_attribute_probabilities("")
+    assert abs(sum(list(all_probabilities.values())) - 1) < 1e-6
+    assert len(all_probabilities) == len(ALL_ATTRIBUTES)
diff --git a/train_scripts/train_unconditional_perturber_base.py b/train_scripts/train_unconditional_perturber_base.py
new file mode 100644
index 0000000..25164e0
--- /dev/null
+++ b/train_scripts/train_unconditional_perturber_base.py
@@ -0,0 +1,31 @@
+from datetime import datetime
+
+from perturbers.training.core import train_perturber
+from perturbers.training.utils import TrainingConfig
+
+
+def main():
+    config = TrainingConfig(
+        conditional=False,
+        model_name="facebook/bart-base",
+        train_steps=20000,
+        val_steps=1000,
+        use_wandb=True,
+        use_gpu=True,
+        max_length=512,
+        train_batch_size=16,
+        test_batch_size=16,
+        accumulate_grad_batches=4,
+        es_patience=5,
+        version=f"unconditional-perturber-base-{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}",
+        learning_rate=1e-5,
+        push_to_hub=True,
+        hub_repo_id="fairnlp/unconditional-perturber-base",
+        num_workers=7,
+    )
+
+    train_perturber(config)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/train_scripts/train_unconditional_perturber_small.py b/train_scripts/train_unconditional_perturber_small.py
new file mode 100644
index 0000000..6c0a40a
--- /dev/null
+++ b/train_scripts/train_unconditional_perturber_small.py
@@ -0,0 +1,31 @@
+from datetime import datetime
+
+from perturbers.training.core import train_perturber
+from perturbers.training.utils import TrainingConfig
+
+
+def main():
+    config = TrainingConfig(
+        conditional=False,
+        model_name="lucadiliello/bart-small",
+        train_steps=20000,
+        val_steps=1000,
+        use_wandb=True,
+        use_gpu=True,
+        max_length=512,
+        train_batch_size=16,
+        test_batch_size=16,
+        accumulate_grad_batches=4,
+        es_patience=5,
+        version=f"unconditional-perturber-small-{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}",
+        learning_rate=1e-5,
+        push_to_hub=True,
+        hub_repo_id="fairnlp/unconditional-perturber-small",
+        num_workers=7,
+    )
+
+    train_perturber(config)
+
+
+if __name__ == "__main__":
+    main()
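For reference, a minimal usage sketch of the API introduced by this diff (model names and call signatures follow the tests above; exact generations depend on the released checkpoints):

from perturbers import Perturber

# Unconditional perturber: the target word and attribute are inferred rather than supplied
model = Perturber("fairnlp/unconditional-perturber-small")

text = "Jack was passionate about rock climbing."

# Probability distribution over demographic attributes implied by the input text
probabilities = model.get_attribute_probabilities(text)

# Let the model propose (word, attribute) conditions and return the most probable perturbation
perturbed = model(text, mode="classify", n_perturbations=3)

# Targeted generation also works and now returns the mean token probability alongside the text
perturbed_text, probability = model.generate(text, word="Jack", attribute="woman")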