diff --git a/dataset/preprocessing/data_processor.py b/dataset/preprocessing/data_processor.py index 156b9bb..266ac2d 100644 --- a/dataset/preprocessing/data_processor.py +++ b/dataset/preprocessing/data_processor.py @@ -17,6 +17,7 @@ from enum import Enum from textattack.augmentation import Augmenter +from textattack.transformations import WordSwapRandomCharacterInsertion from textattack.transformations.word_swaps.word_swap_neighboring_character_swap import WordSwapNeighboringCharacterSwap class MythLabels(Enum): @@ -75,11 +76,19 @@ def _load_data(self): class DataAugmentator: def __init__(self, outer_instance): self.data_processor = outer_instance - self.augmenter = Augmenter( + + self.swap_augmenter = Augmenter( transformation=WordSwapNeighboringCharacterSwap(), + pct_words_to_swap=0.4, + transformations_per_example=random.randint(1, 3) + ) + + self.augmenter_insert = Augmenter( + transformation=WordSwapRandomCharacterInsertion(), pct_words_to_swap=0.5, - transformations_per_example=3 + transformations_per_example=random.randint(1, 3) ) + def augment(self): """ @@ -115,8 +124,10 @@ def augment(self): augmented_data.append(f"{label} {name_no_spaces}") # Esto es de textattack, creo que será buena idea... veamos. - augmented_names = self.augmenter.augment(name) - for aug_name in augmented_names: + swapped_names = self.swap_augmenter.augment(name) + insertions_names = self.augmenter_insert.augment(name) + + for aug_name in swapped_names + insertions_names: augmented_data.append(f"{label} {aug_name}") return augmented_data