From 6c9f327de785dcbe4a133ead7651d6afe9af9ada Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Mon, 18 Mar 2024 23:40:05 +0100 Subject: [PATCH 01/16] Conditional perturber implementation --- perturbers/data/panda_dict.py | 31 +++ perturbers/modeling/perturber.py | 185 +++++++++++++++--- perturbers/training/core.py | 113 ++++++++--- perturbers/training/utils.py | 1 + tests/test_generate.py | 19 ++ tests/test_prepare_inputs.py | 83 ++++++++ tests/test_unconditional_model.py | 20 ++ .../train_unconditional_perturber_base.py | 31 +++ .../train_unconditional_perturber_small.py | 31 +++ 9 files changed, 459 insertions(+), 55 deletions(-) create mode 100644 tests/test_generate.py create mode 100644 tests/test_prepare_inputs.py create mode 100644 tests/test_unconditional_model.py create mode 100644 train_scripts/train_unconditional_perturber_base.py create mode 100644 train_scripts/train_unconditional_perturber_small.py diff --git a/perturbers/data/panda_dict.py b/perturbers/data/panda_dict.py index 79c469c..636ecf2 100644 --- a/perturbers/data/panda_dict.py +++ b/perturbers/data/panda_dict.py @@ -1,3 +1,5 @@ +from typing import Optional, List + from datasets import concatenate_datasets, load_dataset GENDER_ATTRIBUTES = {"man", "woman", "non-binary"} @@ -21,3 +23,32 @@ def get_panda_dict() -> dict[str, list[str]]: perturbation_dict[word] = perturbation_dict.get(word, []) + [attribute] sorted_dict = dict(reversed(sorted(perturbation_dict.items(), key=lambda item: len(item[1])))) return sorted_dict + + +def get_attribute_tokens(attribute_set: Optional[set[str]] = None) -> List[str]: + """ + Creates specific attribute tokens + + Args: + attribute_set: Set of attributes to be used for token generation. If None, all attributes are used. + + Returns: + A list of attribute tokens. 
+ """ + if attribute_set is None: + attribute_set = ALL_ATTRIBUTES + + return [attribute_to_token(attr) for attr in attribute_set] + + +def attribute_to_token(attribute: str) -> str: + """ + Converts an attribute to a token + + Args: + attribute: The attribute to be converted to a token + + Returns: + The token corresponding to the attribute + """ + return f"<{attribute.upper().replace('-', '_').replace(' ', '_')}>" diff --git a/perturbers/modeling/perturber.py b/perturbers/modeling/perturber.py index 5b07bd4..f4397b1 100644 --- a/perturbers/modeling/perturber.py +++ b/perturbers/modeling/perturber.py @@ -1,10 +1,13 @@ +import logging import random from dataclasses import dataclass -from typing import Optional, Union +from typing import Optional, Union, Literal, Tuple, List +import numpy as np from transformers import BartForConditionalGeneration, AutoTokenizer, PreTrainedModel +from transformers.generation.configuration_utils import GenerationConfig -from perturbers.data.panda_dict import get_panda_dict, ALL_ATTRIBUTES +from perturbers.data.panda_dict import get_panda_dict, attribute_to_token, ALL_ATTRIBUTES @dataclass @@ -12,6 +15,12 @@ class PerturberConfig: sep_token: str = '' pert_sep_token: str = '' max_length: int = 128 + conditional: bool = True + + +@dataclass +class UnconditionalPerturberConfig(PerturberConfig): + conditional: bool = False class Perturber: @@ -45,16 +54,54 @@ def __init__( self.model = BartForConditionalGeneration.from_pretrained(model_name) self.config.sep_token = "," self.config.pert_sep_token = "" + if config is None and "unconditional" in model_name: + logging.info("Inferring unconditional perturber from model name") + self.config.conditional = False self.model.config.max_length = self.config.max_length self.tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True) self.tokenizer.add_tokens([self.config.pert_sep_token], special_tokens=True) + if self.model.model.get_input_embeddings().num_embeddings != len(self.tokenizer): + logging.warning("Number of tokens in tokenizer does not match number of tokens in model. 
The model " + "embeddings will be resized by adding random weights.") + self.model.model.resize_token_embeddings(len(self.tokenizer)) self.panda_dict = get_panda_dict() self.input_template = PerturberTemplate(sep=self.config.sep_token, pert_sep=self.config.pert_sep_token, - original=model_name == "facebook/perturber") - - def generate(self, input_txt: str, word: str = "", attribute: str = "", tokenizer_kwargs=None) -> str: + original=model_name == "facebook/perturber", + conditional=self.config.conditional) + + def get_attribute_probabilities(self, input_txt: str): + if self.config.conditional: + raise RuntimeError("Attribute classification is not possible for conditional perturber models") + # TODO unconditional perturber methods for classifying the attribute and + attribute_tokens = [] + + def generate_conditions(self, input_txt: str, n_permutations: int, tokenizer_kwargs: dict) -> List[tuple[str, str]]: + if self.config.conditional: + raise RuntimeError("Attribute classification is not possible for conditional perturber models") + gen_config = GenerationConfig.from_model_config(self.model.config) + gen_config.update(eos_token_id=self.tokenizer.vocab[self.config.pert_sep_token]) + generation = self.model.generate( + **self.tokenizer(input_txt, return_tensors='pt', **tokenizer_kwargs), + generation_config=gen_config, + max_new_tokens=self.model.config.max_length, + num_return_sequences=n_permutations, + num_beams=n_permutations, + ) + attribute_tokens = generation[:, 2] + target_tokens = generation[:, 3:] + + # Hack to prevent double brackets from InputTemplate + attributes = [a[1:-1] for a in self.tokenizer.batch_decode(attribute_tokens, + max_new_tokens=self.model.config.max_length)] + target_words = [w.lstrip() for w in self.tokenizer.batch_decode(target_tokens, skip_special_tokens=True, + max_new_tokens=self.model.config.max_length)] + # TODO filter for non-valid attributes - could theoretically be a hallucatination + return list(zip(target_words, attributes)) + + def generate(self, input_txt: str, word: str = "", attribute: str = "", tokenizer_kwargs: Optional[dict] = None, + generate_kwargs: Optional[dict] = None) -> Tuple[str, float]: """ Generates a perturbed version of the input text. 
@@ -67,24 +114,62 @@ def generate(self, input_txt: str, word: str = "", attribute: str = "", tokenize tokenizer_kwargs: Additional keyword arguments to be passed to the tokenizer + generate_kwargs: Additional keyword arguments to be passed to the generate method + Returns: - Perturbed version of the input text + Perturbed version of the input text along with the average token probability """ if tokenizer_kwargs is None: tokenizer_kwargs = {} - if attribute and attribute not in ALL_ATTRIBUTES: + if generate_kwargs is None: + generate_kwargs = {} + generate_kwargs["return_dict_in_generate"] = True + generate_kwargs["output_scores"] = True + if self.config.conditional and attribute and attribute not in ALL_ATTRIBUTES: # TODO validation for unconditional raise ValueError(f"Attribute {attribute} not in {ALL_ATTRIBUTES}") - input_txt = self.input_template(input_txt, word, attribute) - output_tokens = self.model.generate(**self.tokenizer(input_txt, return_tensors='pt'), **tokenizer_kwargs) - return self.tokenizer.batch_decode( - output_tokens, + if self.config.conditional: + input_txt = self.input_template(input_txt, word, attribute) + tokens = self.tokenizer(input_txt, return_tensors='pt', **tokenizer_kwargs) + # TODO also check BOS tokenization here + outputs = self.model.generate(**tokens, **generate_kwargs) + else: + prefix = self.tokenizer.bos_token + self.input_template.get_sentence_prefix(word, attribute) + encoder_tokens = self.tokenizer(input_txt, return_tensors='pt', **tokenizer_kwargs) + decoder_tokens = self.tokenizer(prefix, return_tensors='pt', add_special_tokens=False, **tokenizer_kwargs) + # TODO gen config with disabled BOS ? + outputs = self.model.generate( + input_ids=encoder_tokens.data["input_ids"], + attention_mask=encoder_tokens.data["attention_mask"], + decoder_input_ids=decoder_tokens.data["input_ids"], + decoder_attention_mask=decoder_tokens.data["attention_mask"], + **generate_kwargs + ) + output_string = self._decode_generation(outputs) + probabilities = self.model.compute_transition_scores(outputs.sequences, outputs.scores).exp() + return output_string, float(probabilities.mean()) + + def _decode_generation(self, outputs): + if self.config.conditional: + decode_tokens = outputs.sequences + else: + output_string = self.tokenizer.batch_decode( # TODO this is ugly - can it be non-batched + outputs.sequences, + skip_special_tokens=False, + max_new_tokens=self.model.config.max_length + )[0] + output_string = self.config.pert_sep_token.join(output_string.split(self.config.pert_sep_token)[1:]) + decode_tokens = self.tokenizer(output_string, return_tensors='pt').input_ids + output_string = self.tokenizer.batch_decode( # TODO this is ugly - can it be non-batched + decode_tokens, skip_special_tokens=True, max_new_tokens=self.model.config.max_length )[0].lstrip() + return output_string - def __call__(self, input_txt, mode='word_list', tokenizer_kwargs=None, retry_unchanged=False - ) -> Union[str, NotImplementedError]: + def __call__(self, input_txt: str, mode: Optional[Literal['word_list', 'highest_prob', 'classify']] = None, + tokenizer_kwargs: Optional[dict] = None, generate_kwargs: Optional[dict] = None, + n_perturbations: int = 1, early_stopping: bool = False) -> Union[str, ValueError]: """ Perturbs the input text using the specified mode and returns the perturbed text. No target word or attribute needs to be specified for this method. 
@@ -92,34 +177,51 @@ def __call__(self, input_txt, mode='word_list', tokenizer_kwargs=None, retry_unc Args: input_txt: The input text to be perturbed - mode: The mode to be used for perturbation. Currently, only 'word_list' is supported + mode: The mode to be used for perturbation tokenizer_kwargs: Additional keyword arguments to be passed to the tokenizer - retry_unchanged: If True, perturbation is retried with different target words/attributes until the output is - different from the input + generate_kwargs: Additional keyword arguments to be passed to the generate method + + n_perturbations: The number of perturbations to be generated + + early_stopping: If True, the first perturbation that differs from the input is returned. Returns: Perturbed version of the input text """ if tokenizer_kwargs is None: tokenizer_kwargs = {} - - if mode == 'highest_prob': - raise NotImplementedError # TODO - elif mode == 'word_list': + if generate_kwargs is None: + generate_kwargs = {} + if n_perturbations < 1: + raise ValueError(f"At least one perturbation needs to be tried. Received {n_perturbations} as argument.") + if mode is not None and mode not in {'word_list', 'classify'}: # TODO rename these + raise ValueError(f"Mode {mode} is invalid. Please choose from 'highest_prob, 'word_list' or 'classify'.") + if mode is None: + mode = 'word_list' if self.config.conditional else 'classify' + + if mode == 'word_list': targets = [w for w in input_txt.split(" ") if w in self.panda_dict] perturbations = [(t, perturbed) for t in targets for perturbed in self.panda_dict[t]] - random.shuffle(perturbations) - for word, attribute in perturbations: - generated_txt = self.generate(input_txt, word=word, attribute=attribute, - tokenizer_kwargs=tokenizer_kwargs) - if generated_txt != input_txt or not retry_unchanged: + elif mode == 'classify': + perturbations = self.generate_conditions(input_txt, n_perturbations, tokenizer_kwargs) + random.shuffle(perturbations) + perturbations = perturbations[:n_perturbations] + texts, probabilities = [], [] + + for word, attribute in perturbations: + generated_txt, probability = self.generate(input_txt, word=word, attribute=attribute, + tokenizer_kwargs=tokenizer_kwargs, + generate_kwargs=generate_kwargs) + if generated_txt != input_txt: + if early_stopping: return generated_txt - else: - raise NotImplementedError + else: + texts.append(generated_txt) + probabilities.append(probability) - return input_txt + return texts[np.argmax(probabilities)] if texts else input_txt class PerturberTemplate: @@ -130,9 +232,32 @@ class PerturberTemplate: their occurrences in the input text. 
""" - def __init__(self, sep: str = ",", pert_sep: str = "", original: bool = False) -> None: + def __init__(self, sep: str = ",", pert_sep: str = "", + original: bool = False, conditional: bool = True) -> None: self.sep = sep + self.conditional = conditional self.pert_sep = pert_sep if not original else f" {pert_sep}" def __call__(self, input_txt: str, word: str = "", attribute: str = "") -> str: - return f"{word}{self.sep} {attribute}{self.pert_sep} {input_txt}" + if not self.conditional: + return f"{attribute_to_token(attribute)} {word}{self.pert_sep} {input_txt}" + else: + return f"{word}{self.sep} {attribute}{self.pert_sep} {input_txt}" + + def get_sentence_prefix(self, word: str = "", attribute: str = "") -> Union[str, ValueError]: + if not self.conditional: + return f"{attribute_to_token(attribute)} {word}{self.pert_sep}" + else: + return ValueError("Sentence prefix not available for conditional perturber") + + def get_attribute_prefix(self, word: str = "", attribute: str = "") -> Union[str, ValueError]: + if not self.conditional: + return "" + else: + return ValueError("Attribute prefix not available for conditional perturber") + + def get_word_prefix(self, word: str = "", attribute: str = "") -> Union[str, ValueError]: + if not self.conditional: + return f"{attribute_to_token(attribute)}" + else: + return ValueError("Word prefix not available for conditional perturber") diff --git a/perturbers/training/core.py b/perturbers/training/core.py index 13147e9..2e23925 100644 --- a/perturbers/training/core.py +++ b/perturbers/training/core.py @@ -13,6 +13,7 @@ from transformers import AutoTokenizer, get_linear_schedule_with_warmup from transformers import PreTrainedTokenizerBase, PreTrainedModel +from perturbers.data.panda_dict import get_attribute_tokens from perturbers.modeling.perturber import PerturberTemplate from perturbers.training.utils import TrainingConfig, get_diff_indices @@ -37,15 +38,18 @@ def __init__(self, c: TrainingConfig, tokenizer: PreTrainedTokenizerBase) -> Non self.test_batch_size = c.test_batch_size self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id) - self.train_metrics = self.get_metric_dict("train") - self.val_metrics = self.get_metric_dict("val") - self.test_metrics = self.get_metric_dict("test") + self.train_metrics = self.get_metric_dict(c, "train") + self.val_metrics = self.get_metric_dict(c, "val") + self.test_metrics = self.get_metric_dict(c, "test") - def get_metric_dict(self, split: str) -> dict[str, torch.nn.Module]: + def get_metric_dict(self, c: TrainingConfig, split: str) -> dict[str, torch.nn.Module]: metrics = { f'{split}_ppl': Perplexity(ignore_index=self.tokenizer.pad_token_id).to(self._device), f'{split}_ppl_perturbed': Perplexity(ignore_index=self.tokenizer.pad_token_id).to(self._device), } + if not c.conditional: + metrics[f'{split}_ppl_word'] = Perplexity(ignore_index=self.tokenizer.pad_token_id).to(self._device) + metrics[f'{split}_ppl_attribute'] = Perplexity(ignore_index=self.tokenizer.pad_token_id).to(self._device) if split == "test": metrics[f'{split}_bleu4'] = BLEUScore(n_gram=4).to(self._device) return metrics @@ -64,8 +68,14 @@ def update_metrics( target=[[_] for _ in self.tokenizer.batch_decode(batch['labels'], skip_special_tokens=True)], ) elif "ppl" in metric_key: - if "perturbed" in metric_key: + idx = None + if metric_key.endswith("perturbed"): idx = batch["perturbed_idx"] + elif metric_key.endswith("word"): + idx = batch["word_idx"] + elif metric_key.endswith("attribute"): + idx = 
batch["attribute_idx"] + if idx is not None: value = metric(preds=outputs[idx].unsqueeze(0), target=batch['labels'][idx].unsqueeze(0)) else: value = metric(preds=outputs, target=batch['labels']) @@ -112,10 +122,7 @@ def forward(self, batch: dict) -> Tuple[torch.Tensor, torch.Tensor]: return outputs.logits, outputs.loss def generate(self, batch: dict) -> List[str]: - generations = self.model.generate( - **{k: v for k, v in batch.items() if k in ["input_ids", "attention_mask"]}, - max_length=batch['input_ids'].shape[-1], - ) + generations = self.model.generate(**{k: v for k, v in batch.items() if k in ["input_ids", "attention_mask"]}) return self.tokenizer.batch_decode(generations, skip_special_tokens=True) def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[dict]]: @@ -152,29 +159,45 @@ def get_collate_fn(c: TrainingConfig, tokenizer: PreTrainedTokenizerBase, tokeni Returns: The collate function for the dataloaders """ - input_template = PerturberTemplate(sep=c.sep_token, pert_sep=c.pert_sep_token, - original=c.model_name == "facebook/perturber") def collate_fn(batch: List) -> dict: original, perturbed = [], [] perturbed_x, perturbed_y = [], [] + attribute_x, attribute_y = [], [] + word_x, word_y = [], [] for i, item in enumerate(batch): - perturbed.append(item['perturbed']) - original.append(input_template(item["original"], item["selected_word"], item["target_attribute"])) - idx = item["perturbed_idx"] - perturbed_x += [i] * len(idx) - perturbed_y += idx + original.append(item["original"]) + perturbed.append(item["perturbed"]) + + perturbed_idx = item["perturbed_idx"] + perturbed_x += [i] * len(perturbed_idx) + perturbed_y += perturbed_idx + + if not c.conditional: + word_idx = item["word_idx"] + word_x += [i] * len(word_idx) + word_y += word_idx + + attribute_idx = item["attribute_idx"] + attribute_x += [i] * len(attribute_idx) + attribute_y += attribute_idx original = tokenizer(original, return_tensors='pt', **tokenizer_kwargs) perturbed = tokenizer(perturbed, return_tensors='pt', **tokenizer_kwargs) - return { + return_dict = { "input_ids": original["input_ids"], "attention_mask": original["attention_mask"], "labels": perturbed["input_ids"], "perturbed_idx": (perturbed_x, perturbed_y), } + if not c.conditional: + return_dict["word_idx"] = (word_x, word_y) + return_dict["attribute_idx"] = (attribute_x, attribute_y) + + return return_dict + return collate_fn @@ -205,15 +228,42 @@ def get_callbacks(c: TrainingConfig) -> List[pl.callbacks.Callback]: ] -def add_indices(sample: dict, tokenizer: PreTrainedTokenizerBase, tokenizer_kwargs: dict) -> dict: +def preprocess_inputs(sample: dict, tokenizer: PreTrainedTokenizerBase, tokenizer_kwargs: dict, c: TrainingConfig, + input_template: PerturberTemplate) -> dict: """ Add the indices of the tokens that are different between the original and perturbed text to the sample dictionary. Function signature is intended to be used with the `map` method of the Hugging Face datasets library. 
""" - sample["perturbed_idx"] = get_diff_indices( + + idx = get_diff_indices( tokenizer(sample['original'], **tokenizer_kwargs).data['input_ids'], - tokenizer(sample['perturbed'], **tokenizer_kwargs).data['input_ids'] + tokenizer(sample['perturbed'], **tokenizer_kwargs).data['input_ids'], ) + + if c.conditional: + sample["perturbed_idx"] = idx + sample['original'] = input_template(sample["original"], sample["selected_word"], sample["target_attribute"]) + else: + + # Account for prefix in perturbed indices + sentence_prefix = input_template.get_sentence_prefix(sample["selected_word"], sample["target_attribute"]) + sentence_offset = len(tokenizer.tokenize(sentence_prefix)) + sample["perturbed_idx"] = [i + sentence_offset for i in idx] + sample["perturbed"] = input_template(sample["perturbed"], sample["selected_word"], sample["target_attribute"]) + + # Add indices for word and attribute + n_cls_tokens = len(tokenizer.tokenize(tokenizer.bos_token)) + word_prefix = input_template.get_word_prefix(sample["selected_word"], sample["target_attribute"]) + word_offset = len(tokenizer.tokenize(word_prefix)) + n_cls_tokens + sample["word_idx"] = [i + word_offset for i in range(len(tokenizer.tokenize(sample["selected_word"])))] + + attribute_prefix = input_template.get_attribute_prefix(sample["selected_word"], sample["target_attribute"]) + attribute_offset = len(tokenizer.tokenize(attribute_prefix)) + n_cls_tokens + sample["attribute_idx"] = [attribute_offset] # Only one attribute token + + for idx_key in ["perturbed_idx", "word_idx", "attribute_idx"]: + sample[idx_key] = [i for i in sample[idx_key] if i < c.max_length] + return sample @@ -227,13 +277,15 @@ def train_perturber(c: TrainingConfig) -> PreTrainedModel: c.train_steps = 10 c.val_steps = 5 c.accumulate_grad_batches = 1 + c.num_workers = 0 - tokenizer = AutoTokenizer.from_pretrained(c.model_name, add_prefix_space=True) - tokenizer.add_tokens([c.sep_token, c.pert_sep_token], special_tokens=True) - tokenizer_kwargs = {"padding": True, "truncation": True, "max_length": c.max_length} + tokenizer, tokenizer_kwargs = get_tokenizer(c) model = LightningWrapper(c, tokenizer) dataset = load_dataset(c.dataset_name) + input_template = PerturberTemplate(sep=c.sep_token, pert_sep=c.pert_sep_token, + original=c.model_name == "facebook/perturber", conditional=c.conditional) + train_ds = dataset["train"] val_ds = dataset["validation"] @@ -241,8 +293,9 @@ def train_perturber(c: TrainingConfig) -> PreTrainedModel: train_ds = train_ds.select(range(128)) val_ds = val_ds.select(range(128)) - train_ds = train_ds.map(lambda x: add_indices(x, tokenizer, tokenizer_kwargs), num_proc=max(c.num_workers, 1)) - val_ds = val_ds.map(lambda x: add_indices(x, tokenizer, tokenizer_kwargs), num_proc=max(c.num_workers, 1)) + map_fn = lambda x: preprocess_inputs(x, tokenizer, tokenizer_kwargs, c, input_template) + train_ds = train_ds.map(map_fn, num_proc=max(c.num_workers, 1)) + val_ds = val_ds.map(map_fn, num_proc=max(c.num_workers, 1)) collate_fn = get_collate_fn(c, tokenizer, tokenizer_kwargs) @@ -285,3 +338,13 @@ def train_perturber(c: TrainingConfig) -> PreTrainedModel: tokenizer.push_to_hub(c.hub_repo_id) return model.model + + +def get_tokenizer(c): + tokenizer = AutoTokenizer.from_pretrained(c.model_name, add_prefix_space=True) + new_tokens = [c.sep_token, c.pert_sep_token] + if not c.conditional: + new_tokens += get_attribute_tokens() + tokenizer.add_tokens(new_tokens, special_tokens=True) + tokenizer_kwargs = {"padding": True, "truncation": True, "max_length": c.max_length} 
+ return tokenizer, tokenizer_kwargs diff --git a/perturbers/training/utils.py b/perturbers/training/utils.py index 18f161e..96489f6 100644 --- a/perturbers/training/utils.py +++ b/perturbers/training/utils.py @@ -6,6 +6,7 @@ @dataclass class TrainingConfig: + conditional: bool = True model_name: str = "facebook/bart-large" dataset_name: str = "facebook/panda" train_batch_size: int = 64 diff --git a/tests/test_generate.py b/tests/test_generate.py new file mode 100644 index 0000000..88a75cf --- /dev/null +++ b/tests/test_generate.py @@ -0,0 +1,19 @@ +from perturbers import Perturber + + +def test_word_list(): + model = Perturber("hf-internal-testing/tiny-random-bart") + model( + mode="word_list", + input_txt="a", + generate_kwargs={"max_new_tokens": 2} + ) + + +def test_highest_prob(): + model = Perturber("hf-internal-testing/tiny-random-bart") + model( + mode="highest_prob", + input_txt="a", + generate_kwargs={"max_new_tokens": 2} + ) diff --git a/tests/test_prepare_inputs.py b/tests/test_prepare_inputs.py new file mode 100644 index 0000000..d70d28d --- /dev/null +++ b/tests/test_prepare_inputs.py @@ -0,0 +1,83 @@ +from typing import List + +from perturbers.data.panda_dict import attribute_to_token +from perturbers.modeling.perturber import PerturberTemplate +from perturbers.training.core import preprocess_inputs, get_tokenizer +from perturbers.training.utils import TrainingConfig + +# mock data +sample = { + "original": "Perturbers are cool!", + "perturbed": "Perturbers are great!", + "selected_word": "cool", + "target_attribute": "non-binary", +} +perturbed_span = " great" +selected_word_span = " cool" +target_attribute_span = attribute_to_token("non-binary") + + +def get_preprocessed_sample(c: TrainingConfig) -> dict: + tokenizer, tokenizer_kwargs = get_tokenizer(c) + input_template = PerturberTemplate(sep=c.sep_token, pert_sep=c.pert_sep_token, + original=c.model_name == "facebook/perturber", conditional=c.conditional) + preprocessed = preprocess_inputs( + sample=sample, + tokenizer=tokenizer, + tokenizer_kwargs=tokenizer_kwargs, + c=c, + input_template=input_template, + ) + return preprocessed + + +def get_span_at_idx(sequence: str, token_idx: List[int], c: TrainingConfig) -> str: + tokenizer, tokenizer_kwargs = get_tokenizer(c) + tokens = tokenizer(sequence, **tokenizer_kwargs) + return tokenizer.decode([tokens['input_ids'][i] for i in token_idx]) + + +def test_prepare_inputs_conditional(): + c = TrainingConfig( + model_name="hf-internal-testing/tiny-random-bart", + debug=True, + max_length=64, + conditional=True, + ) + preprocessed = get_preprocessed_sample(c) + span = get_span_at_idx( + sequence=preprocessed["perturbed"], + token_idx=preprocessed["perturbed_idx"], + c=c, + ) + assert span == perturbed_span + + +def test_prepare_inputs_unconditional(): + c = TrainingConfig( + model_name="hf-internal-testing/tiny-random-bart", + debug=True, + max_length=64, + conditional=False, + ) + preprocessed = get_preprocessed_sample(c) + span = get_span_at_idx( + sequence=preprocessed["perturbed"], + token_idx=preprocessed["perturbed_idx"], + c=c, + ) + assert span == perturbed_span + + span = get_span_at_idx( + sequence=preprocessed["original"], + token_idx=preprocessed["word_idx"], + c=c, + ) + assert span == selected_word_span + + span = get_span_at_idx( + sequence=preprocessed["original"], + token_idx=preprocessed["attribute_idx"], + c=c, + ) + assert span == target_attribute_span diff --git a/tests/test_unconditional_model.py b/tests/test_unconditional_model.py new file mode 100644 index 
0000000..e1a787f --- /dev/null +++ b/tests/test_unconditional_model.py @@ -0,0 +1,20 @@ +from perturbers import Perturber + +UNPERTURBED = "Jack was passionate about rock climbing and his love for the sport was infectious to all men around him." +PERTURBED_ORIGINAL = "Jane was passionate about rock climbing and her love for the sport was infectious to all men around her." +PERTURBED_SMALL = "Jack was passionate about rock climbing and her love for the sport was infectious to all men around her." +PERTURBED_BASE = "Jacqueline was passionate about rock climbing and her love for the sport was infectious to all men around her." + + +def test_small_perturber_model(): + model = Perturber("fairnlp/unconditional-perturber-small") + + perturbed = model.generate(UNPERTURBED, "Jack", "woman") + assert perturbed == PERTURBED_SMALL + + +def test_base_perturber_model(): + model = Perturber("fairnlp/unconditional-perturber-base") + + perturbed = model.generate(UNPERTURBED, "Jack", "woman") + assert perturbed == PERTURBED_BASE diff --git a/train_scripts/train_unconditional_perturber_base.py b/train_scripts/train_unconditional_perturber_base.py new file mode 100644 index 0000000..25164e0 --- /dev/null +++ b/train_scripts/train_unconditional_perturber_base.py @@ -0,0 +1,31 @@ +from datetime import datetime + +from perturbers.training.core import train_perturber +from perturbers.training.utils import TrainingConfig + + +def main(): + config = TrainingConfig( + conditional=False, + model_name="facebook/bart-base", + train_steps=20000, + val_steps=1000, + use_wandb=True, + use_gpu=True, + max_length=512, + train_batch_size=16, + test_batch_size=16, + accumulate_grad_batches=4, + es_patience=5, + version=f"unconditional-perturber-base-{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}", + learning_rate=1e-5, + push_to_hub=True, + hub_repo_id="fairnlp/unconditional-perturber-base", + num_workers=7, + ) + + train_perturber(config) + + +if __name__ == "__main__": + main() diff --git a/train_scripts/train_unconditional_perturber_small.py b/train_scripts/train_unconditional_perturber_small.py new file mode 100644 index 0000000..6c0a40a --- /dev/null +++ b/train_scripts/train_unconditional_perturber_small.py @@ -0,0 +1,31 @@ +from datetime import datetime + +from perturbers.training.core import train_perturber +from perturbers.training.utils import TrainingConfig + + +def main(): + config = TrainingConfig( + conditional=False, + model_name="lucadiliello/bart-small", + train_steps=20000, + val_steps=1000, + use_wandb=True, + use_gpu=True, + max_length=512, + train_batch_size=16, + test_batch_size=16, + accumulate_grad_batches=4, + es_patience=5, + version=f"unconditional-perturber-small-{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}", + learning_rate=1e-5, + push_to_hub=True, + hub_repo_id="fairnlp/unconditional-perturber-small", + num_workers=7, + ) + + train_perturber(config) + + +if __name__ == "__main__": + main() From 628a52963129497bd58d8bb62c21ac4e53821d26 Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Tue, 19 Mar 2024 18:34:05 +0100 Subject: [PATCH 02/16] Add attribute-to-token mapping & remove TODOs --- perturbers/modeling/perturber.py | 37 ++++++++++++++++---------------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/perturbers/modeling/perturber.py b/perturbers/modeling/perturber.py index f4397b1..c24dce3 100644 --- a/perturbers/modeling/perturber.py +++ b/perturbers/modeling/perturber.py @@ -58,6 +58,10 @@ def __init__( logging.info("Inferring unconditional perturber from model name") 
self.config.conditional = False + self.attribute_to_token = {a: attribute_to_token(a) for a in ALL_ATTRIBUTES} + self.token_to_attribute = {t: a for a, t in self.attribute_to_token.items()} + self.attribute_tokens = {attribute_to_token(a) for a in ALL_ATTRIBUTES} + self.model.config.max_length = self.config.max_length self.tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True) self.tokenizer.add_tokens([self.config.pert_sep_token], special_tokens=True) @@ -93,12 +97,12 @@ def generate_conditions(self, input_txt: str, n_permutations: int, tokenizer_kwa target_tokens = generation[:, 3:] # Hack to prevent double brackets from InputTemplate - attributes = [a[1:-1] for a in self.tokenizer.batch_decode(attribute_tokens, - max_new_tokens=self.model.config.max_length)] + attributes = [self.token_to_attribute.get(a) for a in self.tokenizer.batch_decode(attribute_tokens, + max_new_tokens=self.model.config.max_length)] target_words = [w.lstrip() for w in self.tokenizer.batch_decode(target_tokens, skip_special_tokens=True, max_new_tokens=self.model.config.max_length)] - # TODO filter for non-valid attributes - could theoretically be a hallucatination - return list(zip(target_words, attributes)) + # Filter attribute hallucinations (unlikely) + return [(w, a) for idx, (w, a) in enumerate(zip(target_words, attributes)) if a is not None] def generate(self, input_txt: str, word: str = "", attribute: str = "", tokenizer_kwargs: Optional[dict] = None, generate_kwargs: Optional[dict] = None) -> Tuple[str, float]: @@ -126,18 +130,19 @@ def generate(self, input_txt: str, word: str = "", attribute: str = "", tokenize generate_kwargs = {} generate_kwargs["return_dict_in_generate"] = True generate_kwargs["output_scores"] = True - if self.config.conditional and attribute and attribute not in ALL_ATTRIBUTES: # TODO validation for unconditional + + # Validate the attribute -- generated attribute is validated after generation + if self.config.conditional and attribute and attribute not in ALL_ATTRIBUTES: raise ValueError(f"Attribute {attribute} not in {ALL_ATTRIBUTES}") + if self.config.conditional: input_txt = self.input_template(input_txt, word, attribute) tokens = self.tokenizer(input_txt, return_tensors='pt', **tokenizer_kwargs) - # TODO also check BOS tokenization here outputs = self.model.generate(**tokens, **generate_kwargs) else: prefix = self.tokenizer.bos_token + self.input_template.get_sentence_prefix(word, attribute) encoder_tokens = self.tokenizer(input_txt, return_tensors='pt', **tokenizer_kwargs) decoder_tokens = self.tokenizer(prefix, return_tensors='pt', add_special_tokens=False, **tokenizer_kwargs) - # TODO gen config with disabled BOS ? 
outputs = self.model.generate( input_ids=encoder_tokens.data["input_ids"], attention_mask=encoder_tokens.data["attention_mask"], @@ -153,19 +158,15 @@ def _decode_generation(self, outputs): if self.config.conditional: decode_tokens = outputs.sequences else: - output_string = self.tokenizer.batch_decode( # TODO this is ugly - can it be non-batched - outputs.sequences, - skip_special_tokens=False, - max_new_tokens=self.model.config.max_length - )[0] + output_string = self.tokenizer.decode( + outputs.sequences[0], skip_special_tokens=False, max_new_tokens=self.model.config.max_length + ) output_string = self.config.pert_sep_token.join(output_string.split(self.config.pert_sep_token)[1:]) decode_tokens = self.tokenizer(output_string, return_tensors='pt').input_ids - output_string = self.tokenizer.batch_decode( # TODO this is ugly - can it be non-batched - decode_tokens, - skip_special_tokens=True, - max_new_tokens=self.model.config.max_length - )[0].lstrip() - return output_string + output_string = self.tokenizer.decode( + decode_tokens[0], skip_special_tokens=True, max_new_tokens=self.model.config.max_length + ) + return output_string.lstrip() # Remove trailing space from tokenization def __call__(self, input_txt: str, mode: Optional[Literal['word_list', 'highest_prob', 'classify']] = None, tokenizer_kwargs: Optional[dict] = None, generate_kwargs: Optional[dict] = None, From 4f11a012921c42af9553e9321c9f36f678900ba3 Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Tue, 19 Mar 2024 18:34:14 +0100 Subject: [PATCH 03/16] Update unconditional perturber tests --- tests/test_unconditional_model.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/test_unconditional_model.py b/tests/test_unconditional_model.py index e1a787f..ed0982b 100644 --- a/tests/test_unconditional_model.py +++ b/tests/test_unconditional_model.py @@ -1,20 +1,19 @@ from perturbers import Perturber UNPERTURBED = "Jack was passionate about rock climbing and his love for the sport was infectious to all men around him." -PERTURBED_ORIGINAL = "Jane was passionate about rock climbing and her love for the sport was infectious to all men around her." -PERTURBED_SMALL = "Jack was passionate about rock climbing and her love for the sport was infectious to all men around her." -PERTURBED_BASE = "Jacqueline was passionate about rock climbing and her love for the sport was infectious to all men around her." +PERTURBED_SMALL = "Mary was passionate about rock climbing and her love for the sport was infectious to all men around her." +PERTURBED_BASE = "Jane was passionate about rock climbing and her love for the sport was infectious to all men around her." 
def test_small_perturber_model(): model = Perturber("fairnlp/unconditional-perturber-small") - perturbed = model.generate(UNPERTURBED, "Jack", "woman") + perturbed, probability = model.generate(UNPERTURBED, "Jack", "woman") assert perturbed == PERTURBED_SMALL def test_base_perturber_model(): model = Perturber("fairnlp/unconditional-perturber-base") - perturbed = model.generate(UNPERTURBED, "Jack", "woman") + perturbed, probability = model.generate(UNPERTURBED, "Jack", "woman") assert perturbed == PERTURBED_BASE From 6d18b9ded8693e09a87da63c53d68a9b605251cc Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Tue, 19 Mar 2024 18:40:43 +0100 Subject: [PATCH 04/16] Update conditional perturber tests --- tests/test_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index f22540e..698a33c 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -9,19 +9,19 @@ def test_perturber_model(): model = Perturber() - perturbed = model.generate(UNPERTURBED, "Jack", "woman") + perturbed, probability = model.generate(UNPERTURBED, "Jack", "woman") assert perturbed == PERTURBED_ORIGINAL def test_small_perturber_model(): model = Perturber("fairnlp/perturber-small") - perturbed = model.generate(UNPERTURBED, "Jack", "woman") + perturbed, probability = model.generate(UNPERTURBED, "Jack", "woman") assert perturbed == PERTURBED_SMALL def test_base_perturber_model(): model = Perturber("fairnlp/perturber-base") - perturbed = model.generate(UNPERTURBED, "Jack", "woman") + perturbed, probability = model.generate(UNPERTURBED, "Jack", "woman") assert perturbed == PERTURBED_BASE From b6d920aac4564ac0c97145da4a8aa0d68069b3e8 Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Tue, 19 Mar 2024 19:16:09 +0100 Subject: [PATCH 05/16] Remove old token set --- perturbers/modeling/perturber.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/perturbers/modeling/perturber.py b/perturbers/modeling/perturber.py index c24dce3..6842776 100644 --- a/perturbers/modeling/perturber.py +++ b/perturbers/modeling/perturber.py @@ -60,7 +60,6 @@ def __init__( self.attribute_to_token = {a: attribute_to_token(a) for a in ALL_ATTRIBUTES} self.token_to_attribute = {t: a for a, t in self.attribute_to_token.items()} - self.attribute_tokens = {attribute_to_token(a) for a in ALL_ATTRIBUTES} self.model.config.max_length = self.config.max_length self.tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True) @@ -132,8 +131,8 @@ def generate(self, input_txt: str, word: str = "", attribute: str = "", tokenize generate_kwargs["output_scores"] = True # Validate the attribute -- generated attribute is validated after generation - if self.config.conditional and attribute and attribute not in ALL_ATTRIBUTES: - raise ValueError(f"Attribute {attribute} not in {ALL_ATTRIBUTES}") + if self.config.conditional and attribute and attribute not in self.attribute_to_token: + raise ValueError(f"Attribute {attribute} not in {self.attribute_to_token.keys()}") if self.config.conditional: input_txt = self.input_template(input_txt, word, attribute) From f4bd0f840ebc0fea37f8ef68ea114ffa655a449a Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Tue, 19 Mar 2024 19:19:28 +0100 Subject: [PATCH 06/16] Update input modes --- perturbers/modeling/perturber.py | 8 ++++++-- tests/test_generate.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/perturbers/modeling/perturber.py b/perturbers/modeling/perturber.py index 6842776..b750eca 100644 --- 
a/perturbers/modeling/perturber.py +++ b/perturbers/modeling/perturber.py @@ -190,17 +190,21 @@ def __call__(self, input_txt: str, mode: Optional[Literal['word_list', 'highest_ Returns: Perturbed version of the input text """ + # Input validation if tokenizer_kwargs is None: tokenizer_kwargs = {} if generate_kwargs is None: generate_kwargs = {} if n_perturbations < 1: raise ValueError(f"At least one perturbation needs to be tried. Received {n_perturbations} as argument.") - if mode is not None and mode not in {'word_list', 'classify'}: # TODO rename these - raise ValueError(f"Mode {mode} is invalid. Please choose from 'highest_prob, 'word_list' or 'classify'.") + if mode is not None and mode not in {'word_list', 'classify'}: + raise ValueError(f"Mode {mode} is invalid. Please choose from 'word_list' or 'classify'.") if mode is None: mode = 'word_list' if self.config.conditional else 'classify' + elif mode == 'classify' and self.config.conditional: + raise ValueError("Conditional perturber models are not trained for attribute classification") + # Get perturbation targets and words if mode == 'word_list': targets = [w for w in input_txt.split(" ") if w in self.panda_dict] perturbations = [(t, perturbed) for t in targets for perturbed in self.panda_dict[t]] diff --git a/tests/test_generate.py b/tests/test_generate.py index 88a75cf..a304da5 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -13,7 +13,7 @@ def test_word_list(): def test_highest_prob(): model = Perturber("hf-internal-testing/tiny-random-bart") model( - mode="highest_prob", + mode="classify", input_txt="a", generate_kwargs={"max_new_tokens": 2} ) From f43d41f880fac78a3b218a89e0011f82bcfd8ea9 Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Tue, 19 Mar 2024 19:41:00 +0100 Subject: [PATCH 07/16] Update classify test --- tests/test_generate.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_generate.py b/tests/test_generate.py index a304da5..bf5bee4 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -1,4 +1,5 @@ from perturbers import Perturber +from perturbers.modeling.perturber import PerturberConfig def test_word_list(): @@ -10,8 +11,11 @@ def test_word_list(): ) -def test_highest_prob(): - model = Perturber("hf-internal-testing/tiny-random-bart") +def test_classify(): + model = Perturber( + model="hf-internal-testing/tiny-random-bart", + config=PerturberConfig(conditional=False, max_length=8), + ) model( mode="classify", input_txt="a", From 684473691417db8454908c6eda6acf8efda3655f Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Wed, 20 Mar 2024 18:24:01 +0100 Subject: [PATCH 08/16] Update README --- README.md | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index efb1144..e556ece 100644 --- a/README.md +++ b/README.md @@ -60,12 +60,19 @@ approaching 1 **The perplexity and BLEU4 scores are those reported in the original paper and not measured via this codebase. 
+And unconditional perturber models: + +| | Base model | Parameters | Perplexity | Perplexity (perturbed idx) | Perplexity (word) | Perplexity (attribute) | +|-----------------------------------|--------------------------------------------------------------|------------|------------|----------------------------|-------------------|------------------------| +| [unconditional-perturber-small]() | [bart-small](https://huggingface.co/lucadiliello/bart-small) | 70m | 1.101 | 4.366 | 5.220 | 5.592 | +| [unconditional-perturber-base]() | [bart-base](https://huggingface.co/facebook/bart-base) | 139m | 1.082 | 2.830 | 4.730 | 5.413 | + # Roadmap - [x] Add default perturber model - [x] Pretrain small and medium perturber models -- [ ] Train model to identify target words and attributes -- [ ] Add training of unconditional perturber models (i.e. only get a sentence, no target word/attribute) +- [x] Add training of unconditional perturber models (i.e. only get a sentence, no target word/attribute) +- [ ] Add - [ ] Add self-training by pretraining perturber base model (e.g. BART) on self-perturbed data Other features could include From 7165de386a6eef7dcc1ae40949710269c9b0b408 Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Wed, 20 Mar 2024 18:24:08 +0100 Subject: [PATCH 09/16] Bump version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index fb08e26..91e24c0 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,7 @@ tests_require=TEST_REQUIREMENTS, extras_require=EXTRAS_REQUIREMENTS, url="https://github.com/FairNLP/perturbers", - version="0.0.0", + version="0.0.1", long_description=open('README.md').read(), long_description_content_type='text/markdown', zip_safe=False, From 999009c305a7a69ab3de40720f8a6e3e0134089a Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Wed, 20 Mar 2024 19:43:46 +0100 Subject: [PATCH 10/16] Attribute classification --- perturbers/modeling/perturber.py | 34 ++++++++++++++++++++++++++++--- tests/test_unconditional_model.py | 15 ++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/perturbers/modeling/perturber.py b/perturbers/modeling/perturber.py index b750eca..321c021 100644 --- a/perturbers/modeling/perturber.py +++ b/perturbers/modeling/perturber.py @@ -4,6 +4,7 @@ from typing import Optional, Union, Literal, Tuple, List import numpy as np +import torch from transformers import BartForConditionalGeneration, AutoTokenizer, PreTrainedModel from transformers.generation.configuration_utils import GenerationConfig @@ -74,11 +75,38 @@ def __init__( original=model_name == "facebook/perturber", conditional=self.config.conditional) - def get_attribute_probabilities(self, input_txt: str): + def get_attribute_probabilities(self, input_txt: str, softmax: bool = True, + attributes: Optional[dict[str, str]] = None) -> dict[str, float]: + # TODO add option to normalize for prior, e.g. p(man) < p(woman) due to PANDA imbalance + """ + Given a certain input string, get a conditional probability distribution over a specified set of attributes. + + Args: + input_txt: string to base the probability distribution on. + + softmax: if True, apply a softmax over the logits, else return the likelihood values. + + attributes: mapping of attributes to tokens for which to get probabilities. Default is all attributes. 
+ """ + if attributes is None: + attributes = self.attribute_to_token + if self.config.conditional: raise RuntimeError("Attribute classification is not possible for conditional perturber models") - # TODO unconditional perturber methods for classifying the attribute and - attribute_tokens = [] + decoder_inputs = self.tokenizer(self.tokenizer.bos_token, return_tensors='pt', add_special_tokens=False) + logits = self.model( + **self.tokenizer(input_txt, return_tensors='pt'), + decoder_input_ids=decoder_inputs.data["input_ids"], + decoder_attention_mask=decoder_inputs.data["attention_mask"], + ).logits[0, -1] + assert len(logits.shape) == 1 and logits.shape[0] == len(self.tokenizer.vocab) + + if softmax: + token_logits = torch.tensor([logits[self.tokenizer.vocab[t]] for t in attributes.values()]) + softmax_probs = torch.softmax(token_logits, dim=0) + return {a: float(prob) for a, prob in zip(attributes.keys(), softmax_probs)} + else: + return {a: float(logits[self.tokenizer.vocab[t]].exp()) for a, t in attributes.items()} def generate_conditions(self, input_txt: str, n_permutations: int, tokenizer_kwargs: dict) -> List[tuple[str, str]]: if self.config.conditional: diff --git a/tests/test_unconditional_model.py b/tests/test_unconditional_model.py index ed0982b..e11b369 100644 --- a/tests/test_unconditional_model.py +++ b/tests/test_unconditional_model.py @@ -1,4 +1,5 @@ from perturbers import Perturber +from perturbers.data.panda_dict import GENDER_ATTRIBUTES, ALL_ATTRIBUTES, attribute_to_token UNPERTURBED = "Jack was passionate about rock climbing and his love for the sport was infectious to all men around him." PERTURBED_SMALL = "Mary was passionate about rock climbing and her love for the sport was infectious to all men around her." @@ -17,3 +18,17 @@ def test_base_perturber_model(): perturbed, probability = model.generate(UNPERTURBED, "Jack", "woman") assert perturbed == PERTURBED_BASE + + +def test_get_attribute_probabilities(): + model = Perturber("fairnlp/unconditional-perturber-small") + gender_token_map = {a: attribute_to_token(a) for a in GENDER_ATTRIBUTES} + male_probabilities = model.get_attribute_probabilities("He", attributes=gender_token_map) + female_probabilities = model.get_attribute_probabilities("She", attributes=gender_token_map) + assert male_probabilities["woman"] > female_probabilities["woman"] + assert female_probabilities["man"] > male_probabilities["man"] + assert len(male_probabilities) == len(gender_token_map) + + all_probabilities = model.get_attribute_probabilities("") + assert sum(list(all_probabilities.values())) - 1 < 1e-6 + assert len(all_probabilities) == len(ALL_ATTRIBUTES) From d53a49bfeb136e175911573ca290985f23b26d52 Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Wed, 20 Mar 2024 20:05:06 +0100 Subject: [PATCH 11/16] Test token indices wrt. 
perturbed sentence --- tests/test_prepare_inputs.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/test_prepare_inputs.py b/tests/test_prepare_inputs.py index d70d28d..2ec25b3 100644 --- a/tests/test_prepare_inputs.py +++ b/tests/test_prepare_inputs.py @@ -61,23 +61,20 @@ def test_prepare_inputs_unconditional(): conditional=False, ) preprocessed = get_preprocessed_sample(c) - span = get_span_at_idx( + assert get_span_at_idx( sequence=preprocessed["perturbed"], token_idx=preprocessed["perturbed_idx"], c=c, - ) - assert span == perturbed_span + ) == perturbed_span - span = get_span_at_idx( - sequence=preprocessed["original"], + assert get_span_at_idx( + sequence=preprocessed["perturbed"], token_idx=preprocessed["word_idx"], c=c, - ) - assert span == selected_word_span + ) == selected_word_span - span = get_span_at_idx( - sequence=preprocessed["original"], + assert get_span_at_idx( + sequence=preprocessed["perturbed"], token_idx=preprocessed["attribute_idx"], c=c, - ) - assert span == target_attribute_span + ) == target_attribute_span From adf9cbbb2d1bd5887e16335e4d59f42ff4224925 Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Wed, 20 Mar 2024 21:25:20 +0100 Subject: [PATCH 12/16] Fix incomplete race axis --- perturbers/data/panda_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/perturbers/data/panda_dict.py b/perturbers/data/panda_dict.py index 636ecf2..19e3bf4 100644 --- a/perturbers/data/panda_dict.py +++ b/perturbers/data/panda_dict.py @@ -3,7 +3,7 @@ from datasets import concatenate_datasets, load_dataset GENDER_ATTRIBUTES = {"man", "woman", "non-binary"} -RACE_ATTRIBUTES = {"black", "white", "hispanic", "native-american", "pacific-islander"} +RACE_ATTRIBUTES = {"black", "white", "asian", "hispanic", "native-american", "pacific-islander"} AGE_ATTRIBUTES = {"child", "young", "middle-aged", "senior", "adult"} ALL_ATTRIBUTES = GENDER_ATTRIBUTES | RACE_ATTRIBUTES | AGE_ATTRIBUTES From 6fc693c1ce7d234bf2c1561c5caa2eaf72116f16 Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Fri, 22 Mar 2024 18:07:44 +0100 Subject: [PATCH 13/16] Tokenize before training --- perturbers/training/core.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/perturbers/training/core.py b/perturbers/training/core.py index 2e23925..04be2fe 100644 --- a/perturbers/training/core.py +++ b/perturbers/training/core.py @@ -11,6 +11,7 @@ from torchmetrics.text import Perplexity, BLEUScore from transformers import AutoModel, BartForConditionalGeneration # noqa 401 from transformers import AutoTokenizer, get_linear_schedule_with_warmup +from transformers import DataCollatorWithPadding from transformers import PreTrainedTokenizerBase, PreTrainedModel from perturbers.data.panda_dict import get_attribute_tokens @@ -159,6 +160,7 @@ def get_collate_fn(c: TrainingConfig, tokenizer: PreTrainedTokenizerBase, tokeni Returns: The collate function for the dataloaders """ + collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors='pt', padding=True) def collate_fn(batch: List) -> dict: original, perturbed = [], [] @@ -182,8 +184,8 @@ def collate_fn(batch: List) -> dict: attribute_x += [i] * len(attribute_idx) attribute_y += attribute_idx - original = tokenizer(original, return_tensors='pt', **tokenizer_kwargs) - perturbed = tokenizer(perturbed, return_tensors='pt', **tokenizer_kwargs) + original = collator(original) + perturbed = collator(perturbed) return_dict = { "input_ids": original["input_ids"], @@ -264,6 +266,9 @@ def 
preprocess_inputs(sample: dict, tokenizer: PreTrainedTokenizerBase, tokenize for idx_key in ["perturbed_idx", "word_idx", "attribute_idx"]: sample[idx_key] = [i for i in sample[idx_key] if i < c.max_length] + sample['original'] = tokenizer(sample['original'], **tokenizer_kwargs) + sample['perturbed'] = tokenizer(sample['perturbed'], **tokenizer_kwargs) + return sample From d2332a0f4d8b59e5cb6685ef3f3440439779fa49 Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Fri, 22 Mar 2024 18:08:13 +0100 Subject: [PATCH 14/16] Update retrained model performances --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index e556ece..03f4936 100644 --- a/README.md +++ b/README.md @@ -62,10 +62,10 @@ approaching 1 And unconditional perturber models: -| | Base model | Parameters | Perplexity | Perplexity (perturbed idx) | Perplexity (word) | Perplexity (attribute) | -|-----------------------------------|--------------------------------------------------------------|------------|------------|----------------------------|-------------------|------------------------| -| [unconditional-perturber-small]() | [bart-small](https://huggingface.co/lucadiliello/bart-small) | 70m | 1.101 | 4.366 | 5.220 | 5.592 | -| [unconditional-perturber-base]() | [bart-base](https://huggingface.co/facebook/bart-base) | 139m | 1.082 | 2.830 | 4.730 | 5.413 | +| | Base model | Parameters | Perplexity | Perplexity (perturbed idx) | Perplexity (word) | Perplexity (attribute) | +|-----------------------------------------------------------------------------------------------|--------------------------------------------------------------|------------|------------|----------------------------|-------------------|------------------------| +| [unconditional-perturber-small](https://huggingface.co/fairnlp/unconditional-perturber-small) | [bart-small](https://huggingface.co/lucadiliello/bart-small) | 70m | 1.101 | 4.265 | 5.262 | 5.635 | +| [unconditional-perturber-base](https://huggingface.co/fairnlp/unconditional-perturber-base) | [bart-base](https://huggingface.co/facebook/bart-base) | 139m | 1.082 | 2.897 | 4.537 | 5.328 | # Roadmap From ed7f29dfde22dc08b7097e0d710a3c7cb001f517 Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Fri, 22 Mar 2024 22:43:21 +0100 Subject: [PATCH 15/16] Disable tokenization for template tests --- perturbers/training/core.py | 7 ++++--- tests/test_prepare_inputs.py | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/perturbers/training/core.py b/perturbers/training/core.py index 04be2fe..b8b724d 100644 --- a/perturbers/training/core.py +++ b/perturbers/training/core.py @@ -231,7 +231,7 @@ def get_callbacks(c: TrainingConfig) -> List[pl.callbacks.Callback]: def preprocess_inputs(sample: dict, tokenizer: PreTrainedTokenizerBase, tokenizer_kwargs: dict, c: TrainingConfig, - input_template: PerturberTemplate) -> dict: + input_template: PerturberTemplate, tokenize: bool = True) -> dict: """ Add the indices of the tokens that are different between the original and perturbed text to the sample dictionary. Function signature is intended to be used with the `map` method of the Hugging Face datasets library. 
@@ -266,8 +266,9 @@ def preprocess_inputs(sample: dict, tokenizer: PreTrainedTokenizerBase, tokenize for idx_key in ["perturbed_idx", "word_idx", "attribute_idx"]: sample[idx_key] = [i for i in sample[idx_key] if i < c.max_length] - sample['original'] = tokenizer(sample['original'], **tokenizer_kwargs) - sample['perturbed'] = tokenizer(sample['perturbed'], **tokenizer_kwargs) + if tokenize: + sample['original'] = tokenizer(sample['original'], **tokenizer_kwargs) + sample['perturbed'] = tokenizer(sample['perturbed'], **tokenizer_kwargs) return sample diff --git a/tests/test_prepare_inputs.py b/tests/test_prepare_inputs.py index 2ec25b3..690fb78 100644 --- a/tests/test_prepare_inputs.py +++ b/tests/test_prepare_inputs.py @@ -27,6 +27,7 @@ def get_preprocessed_sample(c: TrainingConfig) -> dict: tokenizer_kwargs=tokenizer_kwargs, c=c, input_template=input_template, + tokenize=False, ) return preprocessed From b4b49e150d04630d217339a4add6bc132805d41d Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Fri, 22 Mar 2024 22:45:02 +0100 Subject: [PATCH 16/16] Update perturbed sentence for newly retrained model --- tests/test_unconditional_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_unconditional_model.py b/tests/test_unconditional_model.py index e11b369..3e07f85 100644 --- a/tests/test_unconditional_model.py +++ b/tests/test_unconditional_model.py @@ -3,7 +3,7 @@ UNPERTURBED = "Jack was passionate about rock climbing and his love for the sport was infectious to all men around him." PERTURBED_SMALL = "Mary was passionate about rock climbing and her love for the sport was infectious to all men around her." -PERTURBED_BASE = "Jane was passionate about rock climbing and her love for the sport was infectious to all men around her." +PERTURBED_BASE = "Jacqueline was passionate about rock climbing and her love for the sport was infectious to all men around her." def test_small_perturber_model():
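
Usage sketch (not part of the patches above): a minimal example of the API this series introduces, pieced together from tests/test_unconditional_model.py and tests/test_generate.py. The checkpoint name and return values are the ones referenced in those tests; treat this as an illustrative sketch rather than part of the diff.

from perturbers import Perturber

# Load an unconditional perturber checkpoint; the name is taken from the tests above.
model = Perturber("fairnlp/unconditional-perturber-small")

sentence = ("Jack was passionate about rock climbing and his love for the sport "
            "was infectious to all men around him.")

# generate() now returns the perturbed text together with the mean token probability.
perturbed, probability = model.generate(sentence, word="Jack", attribute="woman")

# For unconditional models, __call__ can propose its own (word, attribute) conditions
# via mode="classify"; it returns the perturbation with the highest average token
# probability, or the input text if nothing changed.
best_perturbation = model(sentence, mode="classify", n_perturbations=3)

# Unconditional models also expose a (softmaxed) distribution over demographic attributes.
attribute_probs = model.get_attribute_probabilities(sentence)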