From 912e85c82a79814f1216a3c9ea7922611a74e3b7 Mon Sep 17 00:00:00 2001 From: farkoo Date: Sun, 28 Jul 2024 11:26:57 +0330 Subject: [PATCH] Adding Necessary Changes for Persian Punctuation Prediction --- src/argparser.py | 8 ++--- src/config.py | 4 ++- src/dataset.py | 77 +++++++++++++++++++++++++++--------------------- src/inference.py | 11 +++---- src/test.py | 14 +++++---- src/train.py | 50 +++++++++++++++++++++---------- 6 files changed, 99 insertions(+), 65 deletions(-) diff --git a/src/argparser.py b/src/argparser.py index d453f34..b2fb801 100644 --- a/src/argparser.py +++ b/src/argparser.py @@ -6,19 +6,19 @@ def parse_arguments(): parser.add_argument('--name', default='punctuation-restore', type=str, help='name of run') parser.add_argument('--cuda', default=True, type=lambda x: (str(x).lower() == 'true'), help='use cuda if available') parser.add_argument('--seed', default=1, type=int, help='random seed') - parser.add_argument('--pretrained-model', default='roberta-large', type=str, help='pretrained language model') + parser.add_argument('--pretrained-model', default='HooshvareLab/bert-base-parsbert-uncased', type=str, help='pretrained language model') parser.add_argument('--freeze-bert', default=False, type=lambda x: (str(x).lower() == 'true'), help='Freeze BERT layers or not') parser.add_argument('--lstm-dim', default=-1, type=int, help='hidden dimension in LSTM layer, if -1 is set equal to hidden dimension in language model') - parser.add_argument('--use-crf', default=False, type=lambda x: (str(x).lower() == 'true'), + parser.add_argument('--use-crf', default=False, type=lambda x: (str(x).lower() == 'false'), help='whether to use CRF layer or not') - parser.add_argument('--data-path', default='data/', type=str, help='path to train/dev/test datasets') + parser.add_argument('--data-path', default='punctuation-restoration/data/', type=str, help='path to train/dev/test datasets') parser.add_argument('--language', default='english', type=str, help='language, available options are english, bangla, english-bangla (for training with both)') parser.add_argument('--sequence-length', default=256, type=int, help='sequence length to use when preparing dataset (default 256)') - parser.add_argument('--augment-rate', default=0., type=float, help='token augmentation probability') + parser.add_argument('--augment-rate', default=0.15, type=float, help='token augmentation probability') parser.add_argument('--augment-type', default='all', type=str, help='which augmentation to use') parser.add_argument('--sub-style', default='unk', type=str, help='replacement strategy for substitution augment') parser.add_argument('--alpha-sub', default=0.4, type=float, help='augmentation rate for substitution') diff --git a/src/config.py b/src/config.py index a1800cf..a0e9d07 100644 --- a/src/config.py +++ b/src/config.py @@ -29,7 +29,8 @@ } # 'O' -> No punctuation -punctuation_dict = {'O': 0, 'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3} +punctuation_dict = {'O': 0, 'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3, 'EXCLAMATION':4, 'COLON':5} + # pretrained model name: (model class, model tokenizer, output dimension, token style) @@ -49,4 +50,5 @@ 'albert-base-v1': (AlbertModel, AlbertTokenizer, 768, 'albert'), 'albert-base-v2': (AlbertModel, AlbertTokenizer, 768, 'albert'), 'albert-large-v2': (AlbertModel, AlbertTokenizer, 1024, 'albert'), + 'HooshvareLab/bert-base-parsbert-uncased': (BertModel, BertTokenizer, 768, 'bert') } diff --git a/src/dataset.py b/src/dataset.py index 5c32a27..1337094 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -2,6 +2,7 @@ from config import * from augmentation import * import numpy as np +from tqdm import tqdm def parse_data(file_path, tokenizer, sequence_len, token_style): @@ -19,41 +20,44 @@ def parse_data(file_path, tokenizer, sequence_len, token_style): lines = [line for line in f.read().split('\n') if line.strip()] idx = 0 # loop until end of the entire text - while idx < len(lines): - x = [TOKEN_IDX[token_style]['START_SEQ']] - y = [0] - y_mask = [1] # which positions we need to consider while evaluating i.e., ignore pad or sub tokens + with tqdm(total=len(lines), desc="Parsing data", unit="line") as pbar: + while idx < len(lines): + x = [TOKEN_IDX[token_style]['START_SEQ']] + y = [0] + y_mask = [1] # which positions we need to consider while evaluating i.e., ignore pad or sub tokens - # loop until we have required sequence length - # -1 because we will have a special end of sequence token at the end - while len(x) < sequence_len - 1 and idx < len(lines): - word, punc = lines[idx].split('\t') - tokens = tokenizer.tokenize(word) - # if taking these tokens exceeds sequence length we finish current sequence with padding - # then start next sequence from this token - if len(tokens) + len(x) >= sequence_len: - break - else: - for i in range(len(tokens) - 1): - x.append(tokenizer.convert_tokens_to_ids(tokens[i])) - y.append(0) - y_mask.append(0) - if len(tokens) > 0: - x.append(tokenizer.convert_tokens_to_ids(tokens[-1])) + # loop until we have required sequence length + # -1 because we will have a special end of sequence token at the end + while len(x) < sequence_len - 1 and idx < len(lines): + word, punc = lines[idx].split('\t') + tokens = tokenizer.tokenize(word) + # if taking these tokens exceeds sequence length we finish current sequence with padding + # then start next sequence from this token + if len(tokens) + len(x) >= sequence_len: + break else: - x.append(TOKEN_IDX[token_style]['UNK']) - y.append(punctuation_dict[punc]) - y_mask.append(1) - idx += 1 - x.append(TOKEN_IDX[token_style]['END_SEQ']) - y.append(0) - y_mask.append(1) - if len(x) < sequence_len: - x = x + [TOKEN_IDX[token_style]['PAD'] for _ in range(sequence_len - len(x))] - y = y + [0 for _ in range(sequence_len - len(y))] - y_mask = y_mask + [0 for _ in range(sequence_len - len(y_mask))] - attn_mask = [1 if token != TOKEN_IDX[token_style]['PAD'] else 0 for token in x] - data_items.append([x, y, attn_mask, y_mask]) + for i in range(len(tokens) - 1): + x.append(tokenizer.convert_tokens_to_ids(tokens[i])) + y.append(0) + y_mask.append(0) + if len(tokens) > 0: + x.append(tokenizer.convert_tokens_to_ids(tokens[-1])) + else: + x.append(TOKEN_IDX[token_style]['UNK']) + y.append(punctuation_dict[punc]) + y_mask.append(1) + idx += 1 + pbar.update(1) # Update progress bar for each line processed + x.append(TOKEN_IDX[token_style]['END_SEQ']) + y.append(0) + y_mask.append(1) + if len(x) < sequence_len: + x = x + [TOKEN_IDX[token_style]['PAD'] for _ in range(sequence_len - len(x))] + y = y + [0 for _ in range(sequence_len - len(y))] + y_mask = y_mask + [0 for _ in range(sequence_len - len(y_mask))] + attn_mask = [1 if token != TOKEN_IDX[token_style]['PAD'] else 0 for token in x] + data_items.append([x, y, attn_mask, y_mask]) + pbar.update(1) # Ensure pbar update in outer loop too return data_items @@ -69,11 +73,16 @@ def __init__(self, files, tokenizer, sequence_len, token_style, is_train=False, :param augment_rate: token augmentation rate when preparing data :param is_train: if false do not apply augmentation """ + print('*') if isinstance(files, list): + print('***') self.data = [] - for file in files: + # for file in files: + for file in tqdm(files, desc="Loading files"): + print('**') self.data += parse_data(file, tokenizer, sequence_len, token_style) else: + print('---') self.data = parse_data(files, tokenizer, sequence_len, token_style) self.sequence_len = sequence_len self.augment_rate = augment_rate diff --git a/src/inference.py b/src/inference.py index cc40b96..8007965 100644 --- a/src/inference.py +++ b/src/inference.py @@ -7,17 +7,17 @@ parser = argparse.ArgumentParser(description='Punctuation restoration inference on text file') parser.add_argument('--cuda', default=True, type=lambda x: (str(x).lower() == 'true'), help='use cuda if available') -parser.add_argument('--pretrained-model', default='xlm-roberta-large', type=str, help='pretrained language model') +parser.add_argument('--pretrained-model', default='HooshvareLab/bert-base-parsbert-uncased', type=str, help='pretrained language model') parser.add_argument('--lstm-dim', default=-1, type=int, help='hidden dimension in LSTM layer, if -1 is set equal to hidden dimension in language model') parser.add_argument('--use-crf', default=False, type=lambda x: (str(x).lower() == 'true'), help='whether to use CRF layer or not') parser.add_argument('--language', default='en', type=str, help='language English (en) oe Bangla (bn)') -parser.add_argument('--in-file', default='data/test_en.txt', type=str, help='path to inference file') -parser.add_argument('--weight-path', default='xlm-roberta-large.pt', type=str, help='model weight path') +parser.add_argument('--in-file', default='/home/ubuntu/punc/punctuation-restoration/data/test_fa.txt', type=str, help='path to inference file') +parser.add_argument('--weight-path', default='/home/ubuntu/punc/punctuation-restoration/out/weights.pt', type=str, help='model weight path') parser.add_argument('--sequence-length', default=256, type=int, help='sequence length to use when preparing dataset (default 256)') -parser.add_argument('--out-file', default='data/test_en_out.txt', type=str, help='output file location') +parser.add_argument('--out-file', default='/home/ubuntu/punc/punctuation-restoration/data/test_fa_out.txt', type=str, help='output file location') args = parser.parse_args() @@ -51,7 +51,8 @@ def inference(): sequence_len = args.sequence_length result = "" decode_idx = 0 - punctuation_map = {0: '', 1: ',', 2: '.', 3: '?'} + + punctuation_map = {0: '', 1: ',', 2: '.', 3: '?', 4: '!', 5: ':'} if args.language != 'en': punctuation_map[2] = 'ред' diff --git a/src/test.py b/src/test.py index 70cc1c2..d1ecc8c 100644 --- a/src/test.py +++ b/src/test.py @@ -11,13 +11,13 @@ parser = argparse.ArgumentParser(description='Punctuation restoration test') parser.add_argument('--cuda', default=True, type=lambda x: (str(x).lower() == 'true'), help='use cuda if available') -parser.add_argument('--pretrained-model', default='roberta-large', type=str, help='pretrained language model') +parser.add_argument('--pretrained-model', default='HooshvareLab/bert-base-parsbert-uncased', type=str, help='pretrained language model') parser.add_argument('--lstm-dim', default=-1, type=int, help='hidden dimension in LSTM layer, if -1 is set equal to hidden dimension in language model') parser.add_argument('--use-crf', default=False, type=lambda x: (str(x).lower() == 'true'), help='whether to use CRF layer or not') parser.add_argument('--data-path', default='data/test', type=str, help='path to test datasets') -parser.add_argument('--weight-path', default='out/weights.pt', type=str, help='model weight path') +parser.add_argument('--weight-path', default='out/5_weights.pt', type=str, help='model weight path') parser.add_argument('--sequence-length', default=256, type=int, help='sequence length to use when preparing dataset (default 256)') parser.add_argument('--batch-size', default=8, type=int, help='batch size (default: 8)') @@ -65,10 +65,10 @@ def test(data_loader): num_iteration = 0 deep_punctuation.eval() # +1 for overall result - tp = np.zeros(1+len(punctuation_dict), dtype=np.int) - fp = np.zeros(1+len(punctuation_dict), dtype=np.int) - fn = np.zeros(1+len(punctuation_dict), dtype=np.int) - cm = np.zeros((len(punctuation_dict), len(punctuation_dict)), dtype=np.int) + tp = np.zeros(1+len(punctuation_dict), dtype=np.int64) + fp = np.zeros(1+len(punctuation_dict), dtype=np.int64) + fn = np.zeros(1+len(punctuation_dict), dtype=np.int64) + cm = np.zeros((len(punctuation_dict), len(punctuation_dict)), dtype=np.int64) correct = 0 total = 0 with torch.no_grad(): @@ -93,6 +93,8 @@ def test(data_loader): # we can ignore this because we know there won't be any punctuation in this position # since we created this position due to padding or sub-word tokenization continue + print('i: ', i) + print('y[i]: ', y[i]) cor = y[i] prd = y_predict[i] if cor == prd: diff --git a/src/train.py b/src/train.py index e05888a..2448495 100644 --- a/src/train.py +++ b/src/train.py @@ -13,7 +13,7 @@ import augmentation torch.multiprocessing.set_sharing_strategy('file_system') # https://github.com/pytorch/pytorch/issues/11201 - +print(1) args = parse_arguments() # for reproducibility @@ -22,6 +22,7 @@ torch.backends.cudnn.benchmark = False np.random.seed(args.seed) +print(2) # tokenizer tokenizer = MODELS[args.pretrained_model][1].from_pretrained(args.pretrained_model) augmentation.tokenizer = tokenizer @@ -34,16 +35,27 @@ aug_type = args.augment_type # Datasets -if args.language == 'english': - train_set = Dataset(os.path.join(args.data_path, 'en/train2012'), tokenizer=tokenizer, sequence_len=sequence_len, +if args.language == 'persian': + train_set = Dataset(os.path.join(args.data_path, 'tsv_train.txt'), tokenizer=tokenizer, sequence_len=sequence_len, token_style=token_style, is_train=True, augment_rate=ar, augment_type=aug_type) - val_set = Dataset(os.path.join(args.data_path, 'en/dev2012'), tokenizer=tokenizer, sequence_len=sequence_len, + val_set = Dataset(os.path.join(args.data_path, 'tsv_dev.txt'), tokenizer=tokenizer, sequence_len=sequence_len, token_style=token_style, is_train=False) - test_set_ref = Dataset(os.path.join(args.data_path, 'en/test2011'), tokenizer=tokenizer, sequence_len=sequence_len, + test_set_ref = Dataset(os.path.join(args.data_path, 'tsv_test.txt'), tokenizer=tokenizer, sequence_len=sequence_len, token_style=token_style, is_train=False) - test_set_asr = Dataset(os.path.join(args.data_path, 'en/test2011asr'), tokenizer=tokenizer, sequence_len=sequence_len, + + test_set = [val_set, test_set_ref] + +elif args.language == 'english': + train_set = Dataset(os.path.join(args.data_path, 'tsv_train.txt'), tokenizer=tokenizer, sequence_len=sequence_len, + token_style=token_style, is_train=True, augment_rate=ar, augment_type=aug_type) + val_set = Dataset(os.path.join(args.data_path, 'tsv_dev.txt'), tokenizer=tokenizer, sequence_len=sequence_len, + token_style=token_style, is_train=False) + test_set_ref = Dataset(os.path.join(args.data_path, 'tsv_test.txt'), tokenizer=tokenizer, sequence_len=sequence_len, + token_style=token_style, is_train=False) + test_set_asr = Dataset(os.path.join(args.data_path, 'tsv_dev3.txt'), tokenizer=tokenizer, sequence_len=sequence_len, token_style=token_style, is_train=False) test_set = [val_set, test_set_ref, test_set_asr] + elif args.language == 'bangla': train_set = Dataset(os.path.join(args.data_path, 'bn/train'), tokenizer=tokenizer, sequence_len=sequence_len, token_style=token_style, is_train=True, augment_rate=ar, augment_type=aug_type) @@ -56,6 +68,7 @@ test_set_asr = Dataset(os.path.join(args.data_path, 'bn/test_asr'), tokenizer=tokenizer, sequence_len=sequence_len, token_style=token_style, is_train=False) test_set = [val_set, test_set_news, test_set_ref, test_set_asr] + elif args.language == 'english-bangla': train_set = Dataset([os.path.join(args.data_path, 'en/train2012'), os.path.join(args.data_path, 'bn/train_bn')], tokenizer=tokenizer, sequence_len=sequence_len, token_style=token_style, is_train=True, @@ -78,20 +91,20 @@ # Data Loaders data_loader_params = { - 'batch_size': args.batch_size, + 'batch_size': 8, # 'shuffle': True, 'num_workers': 1 } train_loader = torch.utils.data.DataLoader(train_set, **data_loader_params) val_loader = torch.utils.data.DataLoader(val_set, **data_loader_params) test_loaders = [torch.utils.data.DataLoader(x, **data_loader_params) for x in test_set] +print('Data Loaded') # logs os.makedirs(args.save_path, exist_ok=True) model_save_path = os.path.join(args.save_path, 'weights.pt') log_path = os.path.join(args.save_path, args.name + '_logs.txt') - # Model device = torch.device('cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu') if args.use_crf: @@ -99,9 +112,13 @@ else: deep_punctuation = DeepPunctuation(args.pretrained_model, freeze_bert=args.freeze_bert, lstm_dim=args.lstm_dim) deep_punctuation.to(device) -criterion = nn.CrossEntropyLoss() -optimizer = torch.optim.Adam(deep_punctuation.parameters(), lr=args.lr, weight_decay=args.decay) +weights = torch.tensor([0.2, 1.0, 1.5, 1.0, 1.0, 1.0], dtype=torch.float32) +criterion = nn.CrossEntropyLoss(weight=weights) # ignore_index=0 + +optimizer = torch.optim.Adam(deep_punctuation.parameters(), lr=args.lr, weight_decay=args.decay) +print(device) +print('Crt and Opt loaded!') def validate(data_loader): """ @@ -142,10 +159,10 @@ def test(data_loader): num_iteration = 0 deep_punctuation.eval() # +1 for overall result - tp = np.zeros(1+len(punctuation_dict), dtype=np.int) - fp = np.zeros(1+len(punctuation_dict), dtype=np.int) - fn = np.zeros(1+len(punctuation_dict), dtype=np.int) - cm = np.zeros((len(punctuation_dict), len(punctuation_dict)), dtype=np.int) + tp = np.zeros(1+len(punctuation_dict), dtype=np.int32) + fp = np.zeros(1+len(punctuation_dict), dtype=np.int32) + fn = np.zeros(1+len(punctuation_dict), dtype=np.int32) + cm = np.zeros((len(punctuation_dict), len(punctuation_dict)), dtype=np.int32) correct = 0 total = 0 with torch.no_grad(): @@ -190,6 +207,7 @@ def test(data_loader): def train(): + print('Train Started!') with open(log_path, 'a') as f: f.write(str(args)+'\n') best_val_acc = 0 @@ -242,7 +260,9 @@ def train(): print(log) if val_acc > best_val_acc: best_val_acc = val_acc - torch.save(deep_punctuation.state_dict(), model_save_path) + torch.save(deep_punctuation.state_dict(), os.path.join(args.save_path, (str(epoch) + '_weights.pt'))) + + torch.save(deep_punctuation.state_dict(), model_save_path) print('Best validation Acc:', best_val_acc) deep_punctuation.load_state_dict(torch.load(model_save_path))