Skip to content

Commit

Permalink
Added batch size for ner
Browse files Browse the repository at this point in the history
  • Loading branch information
fexfl committed Dec 10, 2024
1 parent 446f34c commit 5c30c54
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions mailcom/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

class Pseudonymize:
def __init__(self):
# amount of sentences passed to transformers ner_classification
# -1 corresponds to all sentences
self.n_batch_sentences = 1

self.spacy_default_model_dict = {
"es": "es_core_news_md",
Expand Down Expand Up @@ -58,6 +61,13 @@ def __init__(self):
# records the already replaced names in an email
self.used_first_names = {}

def set_sentence_batch_size(self, batch_size: int):
if batch_size == 0 or batch_size < -1:
raise ValueError(
"Batch size should either be a positive integer or -1 for all sentences."
)
self.n_batch_sentences = batch_size

def init_spacy(self, language: str, model="default"):
if model == "default":
model = self.spacy_default_model_dict[language]
Expand Down Expand Up @@ -201,14 +211,21 @@ def concatenate(self, sentences):
def pseudonymize(self, text: str):
self.reset()
sentences = self.get_sentences(text)
pseudonymized_sentences = []
for sent in sentences:
sent = self.pseudonymize_email_addresses(sent)
ner = self.get_ner(sent)
ps_sent = " ".join(self.pseudonymize_ne(ner, sent)) if ner else sent
batches = [
sentences[n : n + self.n_batch_sentences] # noqa
for n in range(0, len(sentences), self.n_batch_sentences)
]
pseudonymized_batches = []
for batch in batches:
print(batch)
batch = self.concatenate(batch)
print(batch)
batch = self.pseudonymize_email_addresses(batch)
ner = self.get_ner(batch)
ps_sent = " ".join(self.pseudonymize_ne(ner, batch)) if ner else batch
ps_sent = self.pseudonymize_numbers(ps_sent)
pseudonymized_sentences.append(ps_sent)
return self.concatenate(pseudonymized_sentences)
pseudonymized_batches.append(ps_sent)
return self.concatenate(pseudonymized_batches)


def check_dir(path: str) -> bool:
Expand Down Expand Up @@ -242,6 +259,7 @@ def make_dir(path: str):
pseudonymizer = Pseudonymize()
pseudonymizer.init_spacy("fr")
pseudonymizer.init_transformers()
pseudonymizer.set_sentence_batch_size(1000)
for file in io.email_list:
print("Parsing input file {}".format(file))
text = io.get_text(file)
Expand Down

0 comments on commit 5c30c54

Please sign in to comment.