# Unconditional perturber #2

Open · wants to merge 16 commits into `main` · showing changes from all commits
11 changes: 9 additions & 2 deletions README.md
@@ -60,12 +60,19 @@ approaching 1

**The perplexity and BLEU4 scores are those reported in the original paper and not measured via this codebase.

And unconditional perturber models:

| | Base model | Parameters | Perplexity | Perplexity (perturbed idx) | Perplexity (word) | Perplexity (attribute) |
|-----------------------------------------------------------------------------------------------|--------------------------------------------------------------|------------|------------|----------------------------|-------------------|------------------------|
| [unconditional-perturber-small](https://huggingface.co/fairnlp/unconditional-perturber-small) | [bart-small](https://huggingface.co/lucadiliello/bart-small) | 70m | 1.101 | 4.265 | 5.262 | 5.635 |
| [unconditional-perturber-base](https://huggingface.co/fairnlp/unconditional-perturber-base) | [bart-base](https://huggingface.co/facebook/bart-base) | 139m | 1.082 | 2.897 | 4.537 | 5.328 |
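
A minimal usage sketch for these checkpoints, assuming the `Perturber` class is importable from the package root and accepts a Hugging Face model name (the exact API may differ):

```python
from perturbers import Perturber  # assumed import path

# Unconditional perturbers infer the target word and attribute themselves.
perturber = Perturber("fairnlp/unconditional-perturber-base")

# No target word or attribute is supplied; the model proposes its own perturbation.
print(perturber("Jack was passionate about rock climbing."))

# Probability of each demographic attribute being the one to perturb in the input.
print(perturber.get_attribute_probabilities("Jack was passionate about rock climbing."))
```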

# Roadmap

- [x] Add default perturber model
- [x] Pretrain small and medium perturber models
- [ ] Train model to identify target words and attributes
- [ ] Add training of unconditional perturber models (i.e. only get a sentence, no target word/attribute)
- [x] Add training of unconditional perturber models (i.e. only get a sentence, no target word/attribute)
- [ ] Add
- [ ] Add self-training by pretraining perturber base model (e.g. BART) on self-perturbed data

Other features could include
33 changes: 32 additions & 1 deletion perturbers/data/panda_dict.py
@@ -1,7 +1,9 @@
from typing import Optional, List

from datasets import concatenate_datasets, load_dataset

GENDER_ATTRIBUTES = {"man", "woman", "non-binary"}
RACE_ATTRIBUTES = {"black", "white", "hispanic", "native-american", "pacific-islander"}
RACE_ATTRIBUTES = {"black", "white", "asian", "hispanic", "native-american", "pacific-islander"}
AGE_ATTRIBUTES = {"child", "young", "middle-aged", "senior", "adult"}
ALL_ATTRIBUTES = GENDER_ATTRIBUTES | RACE_ATTRIBUTES | AGE_ATTRIBUTES

@@ -21,3 +23,32 @@ def get_panda_dict() -> dict[str, list[str]]:
perturbation_dict[word] = perturbation_dict.get(word, []) + [attribute]
sorted_dict = dict(reversed(sorted(perturbation_dict.items(), key=lambda item: len(item[1]))))
return sorted_dict


def get_attribute_tokens(attribute_set: Optional[set[str]] = None) -> List[str]:
"""
Creates specific attribute tokens

Args:
attribute_set: Set of attributes to be used for token generation. If None, all attributes are used.

Returns:
A list of attribute tokens.
"""
if attribute_set is None:
attribute_set = ALL_ATTRIBUTES

return [attribute_to_token(attr) for attr in attribute_set]


def attribute_to_token(attribute: str) -> str:
"""
Converts an attribute to a token

Args:
attribute: The attribute to be converted to a token

Returns:
The token corresponding to the attribute
"""
return f"<{attribute.upper().replace('-', '_').replace(' ', '_')}>"
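# Usage sketch (illustrative; follows directly from the helpers above):
#
#     >>> attribute_to_token("non-binary")
#     '<NON_BINARY>'
#     >>> attribute_to_token("native-american")
#     '<NATIVE_AMERICAN>'
#     >>> get_attribute_tokens({"man", "woman"})
#     ['<MAN>', '<WOMAN>']  # order follows set iteration
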
225 changes: 191 additions & 34 deletions perturbers/modeling/perturber.py
@@ -1,17 +1,27 @@
import logging
import random
from dataclasses import dataclass
from typing import Optional, Union
from typing import Optional, Union, Literal, Tuple, List

import numpy as np
import torch
from transformers import BartForConditionalGeneration, AutoTokenizer, PreTrainedModel
from transformers.generation.configuration_utils import GenerationConfig

from perturbers.data.panda_dict import get_panda_dict, ALL_ATTRIBUTES
from perturbers.data.panda_dict import get_panda_dict, attribute_to_token, ALL_ATTRIBUTES


@dataclass
class PerturberConfig:
sep_token: str = '<SEP>'
pert_sep_token: str = '<PERT_SEP>'
max_length: int = 128
conditional: bool = True


@dataclass
class UnconditionalPerturberConfig(PerturberConfig):
conditional: bool = False


class Perturber:
@@ -45,16 +55,84 @@ def __init__(
self.model = BartForConditionalGeneration.from_pretrained(model_name)
self.config.sep_token = ","
self.config.pert_sep_token = "<PERT_SEP>"
if config is None and "unconditional" in model_name:
logging.info("Inferring unconditional perturber from model name")
self.config.conditional = False

self.attribute_to_token = {a: attribute_to_token(a) for a in ALL_ATTRIBUTES}
self.token_to_attribute = {t: a for a, t in self.attribute_to_token.items()}

self.model.config.max_length = self.config.max_length
self.tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
self.tokenizer.add_tokens([self.config.pert_sep_token], special_tokens=True)
if self.model.model.get_input_embeddings().num_embeddings != len(self.tokenizer):
logging.warning("Number of tokens in tokenizer does not match number of tokens in model. The model "
"embeddings will be resized by adding random weights.")
self.model.model.resize_token_embeddings(len(self.tokenizer))

self.panda_dict = get_panda_dict()
self.input_template = PerturberTemplate(sep=self.config.sep_token, pert_sep=self.config.pert_sep_token,
original=model_name == "facebook/perturber")
original=model_name == "facebook/perturber",
conditional=self.config.conditional)

def generate(self, input_txt: str, word: str = "", attribute: str = "", tokenizer_kwargs=None) -> str:
def get_attribute_probabilities(self, input_txt: str, softmax: bool = True,
attributes: Optional[dict[str, str]] = None) -> dict[str, float]:
# TODO add option to normalize for prior, e.g. p(man) < p(woman) due to PANDA imbalance
"""
Given a certain input string, get a conditional probability distribution over a specified set of attributes.

Args:
input_txt: string to base the probability distribution on.

softmax: if True, apply a softmax over the logits, else return the likelihood values.

attributes: mapping of attributes to tokens for which to get probabilities. Default is all attributes.

Returns:
A mapping from each attribute to its probability (or unnormalized likelihood when softmax is False).
"""
if attributes is None:
attributes = self.attribute_to_token

if self.config.conditional:
raise RuntimeError("Attribute classification is not possible for conditional perturber models")
decoder_inputs = self.tokenizer(self.tokenizer.bos_token, return_tensors='pt', add_special_tokens=False)
logits = self.model(
**self.tokenizer(input_txt, return_tensors='pt'),
decoder_input_ids=decoder_inputs.data["input_ids"],
decoder_attention_mask=decoder_inputs.data["attention_mask"],
).logits[0, -1]
assert len(logits.shape) == 1 and logits.shape[0] == len(self.tokenizer.vocab)

if softmax:
token_logits = torch.tensor([logits[self.tokenizer.vocab[t]] for t in attributes.values()])
softmax_probs = torch.softmax(token_logits, dim=0)
return {a: float(prob) for a, prob in zip(attributes.keys(), softmax_probs)}
else:
return {a: float(logits[self.tokenizer.vocab[t]].exp()) for a, t in attributes.items()}
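
# Usage sketch (illustrative; assumes an unconditional checkpoint such as
# "fairnlp/unconditional-perturber-base" loaded as `perturber = Perturber(...)`):
#
#     >>> probs = perturber.get_attribute_probabilities("She sold her car last week.")
#     >>> max(probs, key=probs.get)  # attribute judged most likely, e.g. 'woman' (model-dependent)
#
# With softmax=True the values sum to 1 over the requested attributes; with
# softmax=False the unnormalized exp(logit) values are returned instead.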

def generate_conditions(self, input_txt: str, n_permutations: int, tokenizer_kwargs: dict) -> List[tuple[str, str]]:
if self.config.conditional:
raise RuntimeError("Attribute classification is not possible for conditional perturber models")
gen_config = GenerationConfig.from_model_config(self.model.config)
gen_config.update(eos_token_id=self.tokenizer.vocab[self.config.pert_sep_token])
generation = self.model.generate(
**self.tokenizer(input_txt, return_tensors='pt', **tokenizer_kwargs),
generation_config=gen_config,
max_new_tokens=self.model.config.max_length,
num_return_sequences=n_permutations,
num_beams=n_permutations,
)
attribute_tokens = generation[:, 2]
target_tokens = generation[:, 3:]

# Hack to prevent double brackets from InputTemplate
attributes = [self.token_to_attribute.get(a) for a in self.tokenizer.batch_decode(attribute_tokens,
max_new_tokens=self.model.config.max_length)]
target_words = [w.lstrip() for w in self.tokenizer.batch_decode(target_tokens, skip_special_tokens=True,
max_new_tokens=self.model.config.max_length)]
# Filter attribute hallucinations (unlikely)
return [(w, a) for idx, (w, a) in enumerate(zip(target_words, attributes)) if a is not None]
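
# Usage sketch (illustrative; unconditional checkpoints only). Beam search over the
# decoder prefix proposes candidate (target word, attribute) pairs for the input:
#
#     >>> perturber.generate_conditions("My grandfather fixed the radio.",
#     ...                               n_permutations=3, tokenizer_kwargs={})
#     [('grandfather', 'woman'), ('grandfather', 'non-binary'), ...]  # model-dependent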

def generate(self, input_txt: str, word: str = "", attribute: str = "", tokenizer_kwargs: Optional[dict] = None,
generate_kwargs: Optional[dict] = None) -> Tuple[str, float]:
"""
Generates a perturbed version of the input text.

@@ -67,59 +145,115 @@ def generate(self, input_txt: str, word: str = "", attribute: str = "", tokenize

tokenizer_kwargs: Additional keyword arguments to be passed to the tokenizer

generate_kwargs: Additional keyword arguments to be passed to the generate method

Returns:
Perturbed version of the input text
Perturbed version of the input text along with the average token probability
"""

if tokenizer_kwargs is None:
tokenizer_kwargs = {}
if attribute and attribute not in ALL_ATTRIBUTES:
raise ValueError(f"Attribute {attribute} not in {ALL_ATTRIBUTES}")
input_txt = self.input_template(input_txt, word, attribute)
output_tokens = self.model.generate(**self.tokenizer(input_txt, return_tensors='pt'), **tokenizer_kwargs)
return self.tokenizer.batch_decode(
output_tokens,
skip_special_tokens=True,
max_new_tokens=self.model.config.max_length
)[0].lstrip()

def __call__(self, input_txt, mode='word_list', tokenizer_kwargs=None, retry_unchanged=False
) -> Union[str, NotImplementedError]:
if generate_kwargs is None:
generate_kwargs = {}
generate_kwargs["return_dict_in_generate"] = True
generate_kwargs["output_scores"] = True

# Validate the attribute -- generated attribute is validated after generation
if self.config.conditional and attribute and attribute not in self.attribute_to_token:
raise ValueError(f"Attribute {attribute} not in {self.attribute_to_token.keys()}")

if self.config.conditional:
input_txt = self.input_template(input_txt, word, attribute)
tokens = self.tokenizer(input_txt, return_tensors='pt', **tokenizer_kwargs)
outputs = self.model.generate(**tokens, **generate_kwargs)
else:
prefix = self.tokenizer.bos_token + self.input_template.get_sentence_prefix(word, attribute)
encoder_tokens = self.tokenizer(input_txt, return_tensors='pt', **tokenizer_kwargs)
decoder_tokens = self.tokenizer(prefix, return_tensors='pt', add_special_tokens=False, **tokenizer_kwargs)
outputs = self.model.generate(
input_ids=encoder_tokens.data["input_ids"],
attention_mask=encoder_tokens.data["attention_mask"],
decoder_input_ids=decoder_tokens.data["input_ids"],
decoder_attention_mask=decoder_tokens.data["attention_mask"],
**generate_kwargs
)
output_string = self._decode_generation(outputs)
probabilities = self.model.compute_transition_scores(outputs.sequences, outputs.scores).exp()
return output_string, float(probabilities.mean())

def _decode_generation(self, outputs):
if self.config.conditional:
decode_tokens = outputs.sequences
else:
output_string = self.tokenizer.decode(
outputs.sequences[0], skip_special_tokens=False, max_new_tokens=self.model.config.max_length
)
output_string = self.config.pert_sep_token.join(output_string.split(self.config.pert_sep_token)[1:])
decode_tokens = self.tokenizer(output_string, return_tensors='pt').input_ids
output_string = self.tokenizer.decode(
decode_tokens[0], skip_special_tokens=True, max_new_tokens=self.model.config.max_length
)
return output_string.lstrip() # Remove trailing space from tokenization
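
# Usage sketch (illustrative). generate() returns the perturbed text together with
# the mean per-token probability of the generated sequence:
#
#     >>> text, score = perturber.generate("Jack was a doctor.", word="Jack", attribute="woman")
#     >>> text   # perturbed sentence (model-dependent)
#     >>> score  # float in [0, 1]; higher means the model was more confident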

def __call__(self, input_txt: str, mode: Optional[Literal['word_list', 'classify']] = None,
tokenizer_kwargs: Optional[dict] = None, generate_kwargs: Optional[dict] = None,
n_perturbations: int = 1, early_stopping: bool = False) -> str:
"""
Perturbs the input text using the specified mode and returns the perturbed text. No target word or attribute
needs to be specified for this method.

Args:
input_txt: The input text to be perturbed

mode: The mode to be used for perturbation. Currently, only 'word_list' is supported
mode: The mode to be used for perturbation

tokenizer_kwargs: Additional keyword arguments to be passed to the tokenizer

retry_unchanged: If True, perturbation is retried with different target words/attributes until the output is
different from the input
generate_kwargs: Additional keyword arguments to be passed to the generate method

n_perturbations: The number of perturbations to be generated

early_stopping: If True, the first perturbation that differs from the input is returned.

Returns:
Perturbed version of the input text
"""
# Input validation
if tokenizer_kwargs is None:
tokenizer_kwargs = {}

if mode == 'highest_prob':
raise NotImplementedError # TODO
elif mode == 'word_list':
if generate_kwargs is None:
generate_kwargs = {}
if n_perturbations < 1:
raise ValueError(f"At least one perturbation needs to be tried. Received {n_perturbations} as argument.")
if mode is not None and mode not in {'word_list', 'classify'}:
raise ValueError(f"Mode {mode} is invalid. Please choose from 'word_list' or 'classify'.")
if mode is None:
mode = 'word_list' if self.config.conditional else 'classify'
elif mode == 'classify' and self.config.conditional:
raise ValueError("Conditional perturber models are not trained for attribute classification")

# Get perturbation targets and words
if mode == 'word_list':
targets = [w for w in input_txt.split(" ") if w in self.panda_dict]
perturbations = [(t, perturbed) for t in targets for perturbed in self.panda_dict[t]]
random.shuffle(perturbations)
for word, attribute in perturbations:
generated_txt = self.generate(input_txt, word=word, attribute=attribute,
tokenizer_kwargs=tokenizer_kwargs)
if generated_txt != input_txt or not retry_unchanged:
elif mode == 'classify':
perturbations = self.generate_conditions(input_txt, n_perturbations, tokenizer_kwargs)
random.shuffle(perturbations)
perturbations = perturbations[:n_perturbations]
texts, probabilities = [], []

for word, attribute in perturbations:
generated_txt, probability = self.generate(input_txt, word=word, attribute=attribute,
tokenizer_kwargs=tokenizer_kwargs,
generate_kwargs=generate_kwargs)
if generated_txt != input_txt:
if early_stopping:
return generated_txt
else:
raise NotImplementedError
else:
texts.append(generated_txt)
probabilities.append(probability)

return input_txt
return texts[np.argmax(probabilities)] if texts else input_txt
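
# Usage sketch (illustrative). With mode=None, conditional checkpoints default to
# 'word_list' (candidates come from the PANDA dictionary) and unconditional
# checkpoints default to 'classify' (candidates come from generate_conditions).
# The best-scoring perturbation is returned, or the input itself if nothing changed:
#
#     >>> perturber("The boy threw the ball to his dad.", n_perturbations=3)
#     'The girl threw the ball to her dad.'  # model-dependent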


class PerturberTemplate:
@@ -130,9 +264,32 @@ class PerturberTemplate:
their occurrences in the input text.
"""

def __init__(self, sep: str = ",", pert_sep: str = "<PERT_SEP>", original: bool = False) -> None:
def __init__(self, sep: str = ",", pert_sep: str = "<PERT_SEP>",
original: bool = False, conditional: bool = True) -> None:
self.sep = sep
self.conditional = conditional
self.pert_sep = pert_sep if not original else f" {pert_sep}"

def __call__(self, input_txt: str, word: str = "", attribute: str = "") -> str:
return f"{word}{self.sep} {attribute}{self.pert_sep} {input_txt}"
if not self.conditional:
return f"{attribute_to_token(attribute)} {word}{self.pert_sep} {input_txt}"
else:
return f"{word}{self.sep} {attribute}{self.pert_sep} {input_txt}"

def get_sentence_prefix(self, word: str = "", attribute: str = "") -> Union[str, ValueError]:
if not self.conditional:
return f"{attribute_to_token(attribute)} {word}{self.pert_sep}"
else:
return ValueError("Sentence prefix not available for conditional perturber")

def get_attribute_prefix(self, word: str = "", attribute: str = "") -> Union[str, ValueError]:
if not self.conditional:
return ""
else:
return ValueError("Attribute prefix not available for conditional perturber")

def get_word_prefix(self, word: str = "", attribute: str = "") -> Union[str, ValueError]:
if not self.conditional:
return f"{attribute_to_token(attribute)}"
else:
return ValueError("Word prefix not available for conditional perturber")
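
# Template format sketch (follows directly from __call__ above, default separators):
#
#     >>> PerturberTemplate()("He went home.", "he", "woman")
#     'he, woman<PERT_SEP> He went home.'
#     >>> PerturberTemplate(conditional=False)("He went home.", "he", "woman")
#     '<WOMAN> he<PERT_SEP> He went home.'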