Skip to content

Commit

Permalink
Add a new augmentation method with Textattack: random character insertion
Browse files Browse the repository at this point in the history
  • Loading branch information
geru-scotland committed Nov 9, 2024
1 parent a3d3702 commit 2f96888
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions dataset/preprocessing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from enum import Enum

from textattack.augmentation import Augmenter
from textattack.transformations import WordSwapRandomCharacterInsertion
from textattack.transformations.word_swaps.word_swap_neighboring_character_swap import WordSwapNeighboringCharacterSwap

class MythLabels(Enum):
Expand Down Expand Up @@ -75,11 +76,19 @@ def _load_data(self):
class DataAugmentator:
def __init__(self, outer_instance):
self.data_processor = outer_instance
self.augmenter = Augmenter(

self.swap_augmenter = Augmenter(
transformation=WordSwapNeighboringCharacterSwap(),
pct_words_to_swap=0.4,
transformations_per_example=random.randint(1, 3)
)

self.augmenter_insert = Augmenter(
transformation=WordSwapRandomCharacterInsertion(),
pct_words_to_swap=0.5,
transformations_per_example=3
transformations_per_example=random.randint(1, 3)
)

def augment(self):
"""
Expand Down Expand Up @@ -115,8 +124,10 @@ def augment(self):
augmented_data.append(f"{label} {name_no_spaces}")

# This comes from textattack; I think it will be a good idea... let's see.
augmented_names = self.augmenter.augment(name)
for aug_name in augmented_names:
swapped_names = self.swap_augmenter.augment(name)
insertions_names = self.augmenter_insert.augment(name)

for aug_name in swapped_names + insertions_names:
augmented_data.append(f"{label} {aug_name}")

return augmented_data
Expand Down

0 comments on commit 2f96888

Please sign in to comment.