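"""Sentence perturbation toolkit: generates controlled variants of a sentence via
synonym, antonym, related-word, random, and masked-language-model replacements,
then scores each variant's semantic similarity to the original with Sentence-BERT."""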
import spacy
import nltk
from nltk.corpus import wordnet, brown
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
import random
import numpy as np
from sentence_transformers import SentenceTransformer
from collections import Counter
from annoy import AnnoyIndex # For efficient nearest neighbor search
# Download required NLTK data (do this only once, outside the class)
try:
    nltk.data.find('corpora/wordnet')
except LookupError:
    nltk.download('wordnet')
try:
    nltk.data.find('corpora/wordnet_ic')
except LookupError:
    nltk.download('wordnet_ic')
try:
    nltk.data.find('corpora/brown')
except LookupError:
    nltk.download('brown')
try:
    nltk.data.find('taggers/averaged_perceptron_tagger')  # Needed for POS tagging
except LookupError:
    nltk.download('averaged_perceptron_tagger')

class LinguisticQuantumPerturber:
    def __init__(self, model_name="distilbert-base-uncased", use_gpu=True):
        self.device = "cuda" if torch.cuda.is_available() and use_gpu else "cpu"
        self.nlp = spacy.load("en_core_web_lg")
        self.model = AutoModelForMaskedLM.from_pretrained(model_name).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.sbert = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)

        # Build an Annoy index over the spaCy word vectors for fast nearest-neighbor search
        self.embedding_dim = self.nlp.vocab.vectors_length
        self.annoy_index = AnnoyIndex(self.embedding_dim, 'angular')  # Angular distance approximates cosine similarity
        self.word_to_index = {}
        self.index_to_word = {}  # Reverse map so Annoy item ids can be resolved back to words
        print("Building Annoy index for nearest-neighbor search...")
        for i, lexeme in enumerate(self.nlp.vocab):
            if lexeme.has_vector and lexeme.is_alpha and lexeme.lower_ not in self.word_to_index:
                self.annoy_index.add_item(i, lexeme.vector)
                self.word_to_index[lexeme.lower_] = i
                self.index_to_word[i] = lexeme.lower_
        self.annoy_index.build(10)  # 10 trees: a balance between query speed and accuracy
        print("Annoy index built.")

        # Pre-compute Brown corpus frequencies for frequency-based filtering
        self.word_frequencies = Counter(brown.words())
        self.total_words = sum(self.word_frequencies.values())

    def _get_word_frequency(self, word):
        """Gets the relative frequency of a word (using the Brown corpus)."""
        return self.word_frequencies[word.lower()] / self.total_words

    def _get_synonyms(self, word, pos):
        """Gets synonyms for a word with a given part of speech."""
        synonyms = set()
        for synset in wordnet.synsets(word, pos=self._spacy_to_wordnet_pos(pos)):
            for lemma in synset.lemmas():
                synonym = lemma.name().replace("_", " ")
                if synonym.lower() != word.lower():
                    synonyms.add(synonym)
        return list(synonyms)

    def _get_antonyms(self, word, pos):
        """Gets antonyms for a word with a given part of speech."""
        antonyms = set()
        for synset in wordnet.synsets(word, pos=self._spacy_to_wordnet_pos(pos)):
            for lemma in synset.lemmas():
                for antonym in lemma.antonyms():
                    antonyms.add(antonym.name().replace("_", " "))
        return list(antonyms)

    def _get_related_words(self, word, pos, similarity_threshold=0.2, top_n=20):
        """Gets semantically related words using Annoy for fast nearest-neighbor search."""
        try:
            word_index = self.word_to_index[word.lower()]
        except KeyError:
            return []  # Word not in vocabulary
        vector = self.annoy_index.get_item_vector(word_index)
        # Fetch extra neighbors, then filter down to top_n
        nearest_indices = self.annoy_index.get_nns_by_vector(vector, top_n * 5, include_distances=False)
        related_words = []
        for i in nearest_indices:
            # Resolve the Annoy item id through the reverse map; indexing nlp.vocab with
            # the raw integer would treat it as a lexeme hash, not an index.
            neighbor_word = self.index_to_word[i]
            if self.nlp(neighbor_word)[0].pos_ == pos:  # Indexed words are already alphabetic with vectors
                similarity = self.nlp.vocab[word.lower()].similarity(self.nlp.vocab[neighbor_word])
                if similarity_threshold <= similarity < 0.98:  # Exclude (near-)identical words
                    related_words.append(neighbor_word)
                    if len(related_words) >= top_n:
                        break  # Enforce top_n after filtering
        return related_words

    def _spacy_to_wordnet_pos(self, spacy_pos):
        """Converts a spaCy POS tag to a WordNet POS tag (None if unmapped)."""
        pos_map = {
            'NOUN': wordnet.NOUN,
            'VERB': wordnet.VERB,
            'ADJ': wordnet.ADJ,
            'ADV': wordnet.ADV
        }
        return pos_map.get(spacy_pos)

    def _mlm_perturb(self, words, masked_index):
        """Perturbs one word using the model's masked-language-modeling head.

        `words` is the list of tokens from the spaCy doc and `masked_index` is the
        index of the word to mask; aligning through the tokenizer's word ids avoids
        the off-by-subword errors of assuming one transformer token per word."""
        inputs = self.tokenizer(words, is_split_into_words=True, return_tensors="pt").to(self.device)
        # Map the spaCy word index to its sub-word token positions (requires a fast tokenizer)
        token_positions = [pos for pos, wid in enumerate(inputs.word_ids(0)) if wid == masked_index]
        if not token_positions:
            return []
        # Mask only the first sub-word; multi-subword words are approximated by a single prediction
        mask_position = token_positions[0]
        inputs["input_ids"][0][mask_position] = self.tokenizer.mask_token_id
        with torch.no_grad():
            outputs = self.model(**inputs)
        predictions = outputs.logits[0, mask_position].topk(10)  # Top 10 predictions for the mask
        predicted_tokens = []
        for token_id in predictions.indices.tolist():
            token = self.tokenizer.decode(token_id).strip()
            if token and not token.startswith("##"):  # Skip sub-word continuations
                predicted_tokens.append(token)
        return predicted_tokens

    def _filter_replacements(self, word, replacements, pos, frequency_threshold_factor=2.0):
        """Filters replacements based on POS tag and corpus frequency."""
        word_freq = self._get_word_frequency(word)
        min_freq = word_freq / frequency_threshold_factor
        max_freq = word_freq * frequency_threshold_factor
        filtered_replacements = []
        for replacement in replacements:
            if not replacement:  # Skip empty replacements
                continue
            replacement_doc = self.nlp(replacement)
            if not replacement_doc:  # Make sure the replacement can be processed
                continue
            replacement_token = replacement_doc[0]  # Assuming single-word replacements
            replacement_freq = self._get_word_frequency(replacement)  # Compute once instead of twice
            if replacement_token.pos_ == pos and min_freq <= replacement_freq <= max_freq:
                filtered_replacements.append(replacement)
        return filtered_replacements

    def perturb_sentence(self, sentence, perturbation_type, target_word_index=None, num_perturbations=1,
                         similarity_threshold=0.2, frequency_threshold_factor=2.0):
        """Perturbs a sentence with various methods and filters.

        Args:
            sentence: Input sentence.
            perturbation_type: 'synonym', 'antonym', 'related', 'random', or 'mlm'.
            target_word_index: Optional index of the word to perturb.
            num_perturbations: Number of words to perturb (chosen at random).
            similarity_threshold: Minimum similarity for 'related' perturbations.
            frequency_threshold_factor: How far a replacement's corpus frequency may differ.

        Returns:
            List of perturbed sentences.
        """
        doc = self.nlp(sentence)
        words = [token.text for token in doc]  # Work with a list of strings
        perturbed_sentences = []

        # Only alphabetic, non-stopword content words are eligible for perturbation
        valid_indices = [i for i, token in enumerate(doc)
                         if token.pos_ in ['NOUN', 'VERB', 'ADJ', 'ADV'] and token.is_alpha
                         and not token.is_stop]
        if not valid_indices:
            return []

        if target_word_index is not None:
            if target_word_index not in valid_indices:
                return []
            indices_to_perturb = [target_word_index]
        else:
            num_perturbations = min(num_perturbations, len(valid_indices))
            indices_to_perturb = random.sample(valid_indices, num_perturbations)

        for i in indices_to_perturb:
            token = doc[i]
            word = token.text
            pos = token.pos_
            if perturbation_type == "synonym":
                replacements = self._get_synonyms(word, pos)
            elif perturbation_type == "antonym":
                replacements = self._get_antonyms(word, pos)
            elif perturbation_type == "related":
                replacements = self._get_related_words(word, pos, similarity_threshold)
            elif perturbation_type == "random":
                # Draw random words with similar corpus frequency and the same POS
                min_freq = self._get_word_frequency(word) / frequency_threshold_factor
                max_freq = self._get_word_frequency(word) * frequency_threshold_factor
                candidates = [w for w in self.word_frequencies
                              if w.isalpha() and min_freq <= self._get_word_frequency(w) <= max_freq]
                random.shuffle(candidates)
                replacements = []
                for w in candidates:
                    # POS-tag lazily: running spaCy over the whole Brown vocabulary is the expensive part
                    if self.nlp(w)[0].pos_ == pos:
                        replacements.append(w)
                    if len(replacements) >= 10:  # Sample up to 10 random candidates
                        break
            elif perturbation_type == "mlm":
                replacements = self._mlm_perturb(words, i)
            else:
                raise ValueError("Invalid perturbation_type")

            filtered_replacements = self._filter_replacements(word, replacements, pos, frequency_threshold_factor)
            for replacement in filtered_replacements:
                new_words = words[:]  # Copy the original word list
                new_words[i] = replacement
                perturbed_sentences.append(" ".join(new_words))
        return perturbed_sentences

    def compute_semantic_similarity(self, original_sentence, perturbed_sentence):
        """Computes the semantic similarity between two sentences using Sentence-BERT."""
        original_embedding = self.sbert.encode(original_sentence, convert_to_tensor=True)
        perturbed_embedding = self.sbert.encode(perturbed_sentence, convert_to_tensor=True)
        similarity = torch.nn.functional.cosine_similarity(original_embedding, perturbed_embedding, dim=0)
        return similarity.item()
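

# A minimal usage sketch, not part of the original class: the sentence, perturbation
# types, and parameter values below are illustrative assumptions. Expect the first run
# to be slow, since the Annoy index is built from the full spaCy vocabulary at startup.
if __name__ == "__main__":
    perturber = LinguisticQuantumPerturber(use_gpu=False)
    sentence = "The quick brown fox jumps over the lazy dog."
    for ptype in ["synonym", "related", "mlm"]:  # 'antonym' and 'random' work the same way
        variants = perturber.perturb_sentence(sentence, ptype, num_perturbations=1)
        for variant in variants[:3]:  # Show at most three variants per perturbation type
            score = perturber.compute_semantic_similarity(sentence, variant)
            print(f"{ptype:>8}: {variant} (similarity={score:.3f})")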