Adding Necessary Changes for Persian Punctuation Prediction #18

Open
wants to merge 1 commit into master
8 changes: 4 additions & 4 deletions src/argparser.py
@@ -6,19 +6,19 @@ def parse_arguments():
parser.add_argument('--name', default='punctuation-restore', type=str, help='name of run')
parser.add_argument('--cuda', default=True, type=lambda x: (str(x).lower() == 'true'), help='use cuda if available')
parser.add_argument('--seed', default=1, type=int, help='random seed')
parser.add_argument('--pretrained-model', default='roberta-large', type=str, help='pretrained language model')
parser.add_argument('--pretrained-model', default='HooshvareLab/bert-base-parsbert-uncased', type=str, help='pretrained language model')
parser.add_argument('--freeze-bert', default=False, type=lambda x: (str(x).lower() == 'true'),
help='Freeze BERT layers or not')
parser.add_argument('--lstm-dim', default=-1, type=int,
help='hidden dimension in LSTM layer, if -1 is set equal to hidden dimension in language model')
parser.add_argument('--use-crf', default=False, type=lambda x: (str(x).lower() == 'true'),
parser.add_argument('--use-crf', default=False, type=lambda x: (str(x).lower() == 'false'),
help='whether to use CRF layer or not')
parser.add_argument('--data-path', default='data/', type=str, help='path to train/dev/test datasets')
parser.add_argument('--data-path', default='punctuation-restoration/data/', type=str, help='path to train/dev/test datasets')
parser.add_argument('--language', default='english', type=str,
help='language, available options are english, bangla, english-bangla (for training with both)')
parser.add_argument('--sequence-length', default=256, type=int,
help='sequence length to use when preparing dataset (default 256)')
parser.add_argument('--augment-rate', default=0., type=float, help='token augmentation probability')
parser.add_argument('--augment-rate', default=0.15, type=float, help='token augmentation probability')
parser.add_argument('--augment-type', default='all', type=str, help='which augmentation to use')
parser.add_argument('--sub-style', default='unk', type=str, help='replacement strategy for substitution augment')
parser.add_argument('--alpha-sub', default=0.4, type=float, help='augmentation rate for substitution')
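For reference, these boolean flags are parsed by comparing the raw CLI string against a literal, so only that exact string flips the value. Below is a minimal standalone sketch of the idiom (not taken from the repo); note that the --use-crf line above now compares against 'false', so an explicit --use-crf false would evaluate to True while --use-crf true evaluates to False:

to_bool = lambda x: str(x).lower() == 'true'

assert to_bool('True') is True       # case-insensitive match on 'true'
assert to_bool('false') is False
assert to_bool('1') is False         # anything other than 'true' is False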
4 changes: 3 additions & 1 deletion src/config.py
@@ -29,7 +29,8 @@
}

# 'O' -> No punctuation
punctuation_dict = {'O': 0, 'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3}
punctuation_dict = {'O': 0, 'COMMA': 1, 'PERIOD': 2, 'QUESTION': 3, 'EXCLAMATION':4, 'COLON':5}



# pretrained model name: (model class, model tokenizer, output dimension, token style)
@@ -49,4 +50,5 @@
'albert-base-v1': (AlbertModel, AlbertTokenizer, 768, 'albert'),
'albert-base-v2': (AlbertModel, AlbertTokenizer, 768, 'albert'),
'albert-large-v2': (AlbertModel, AlbertTokenizer, 1024, 'albert'),
'HooshvareLab/bert-base-parsbert-uncased': (BertModel, BertTokenizer, 768, 'bert')
}
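For context, the new MODELS entry is consumed the same way as the existing ones; train.py below looks up the tokenizer with MODELS[args.pretrained_model][1].from_pretrained(...). A minimal sketch of that lookup (assuming config.py is importable as shown):

from config import MODELS, punctuation_dict

pretrained = 'HooshvareLab/bert-base-parsbert-uncased'
model_cls, tokenizer_cls, output_dim, token_style = MODELS[pretrained]

tokenizer = tokenizer_cls.from_pretrained(pretrained)  # BertTokenizer
encoder = model_cls.from_pretrained(pretrained)         # BertModel, 768-dim outputs

num_classes = len(punctuation_dict)  # now 6: O, COMMA, PERIOD, QUESTION, EXCLAMATION, COLON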
77 changes: 43 additions & 34 deletions src/dataset.py
@@ -2,6 +2,7 @@
from config import *
from augmentation import *
import numpy as np
from tqdm import tqdm


def parse_data(file_path, tokenizer, sequence_len, token_style):
@@ -19,41 +20,44 @@ def parse_data(file_path, tokenizer, sequence_len, token_style):
lines = [line for line in f.read().split('\n') if line.strip()]
idx = 0
# loop until end of the entire text
while idx < len(lines):
x = [TOKEN_IDX[token_style]['START_SEQ']]
y = [0]
y_mask = [1] # which positions we need to consider while evaluating i.e., ignore pad or sub tokens
with tqdm(total=len(lines), desc="Parsing data", unit="line") as pbar:
while idx < len(lines):
x = [TOKEN_IDX[token_style]['START_SEQ']]
y = [0]
y_mask = [1] # which positions we need to consider while evaluating i.e., ignore pad or sub tokens

# loop until we have required sequence length
# -1 because we will have a special end of sequence token at the end
while len(x) < sequence_len - 1 and idx < len(lines):
word, punc = lines[idx].split('\t')
tokens = tokenizer.tokenize(word)
# if taking these tokens exceeds sequence length we finish current sequence with padding
# then start next sequence from this token
if len(tokens) + len(x) >= sequence_len:
break
else:
for i in range(len(tokens) - 1):
x.append(tokenizer.convert_tokens_to_ids(tokens[i]))
y.append(0)
y_mask.append(0)
if len(tokens) > 0:
x.append(tokenizer.convert_tokens_to_ids(tokens[-1]))
# loop until we have required sequence length
# -1 because we will have a special end of sequence token at the end
while len(x) < sequence_len - 1 and idx < len(lines):
word, punc = lines[idx].split('\t')
tokens = tokenizer.tokenize(word)
# if taking these tokens exceeds sequence length we finish current sequence with padding
# then start next sequence from this token
if len(tokens) + len(x) >= sequence_len:
break
else:
x.append(TOKEN_IDX[token_style]['UNK'])
y.append(punctuation_dict[punc])
y_mask.append(1)
idx += 1
x.append(TOKEN_IDX[token_style]['END_SEQ'])
y.append(0)
y_mask.append(1)
if len(x) < sequence_len:
x = x + [TOKEN_IDX[token_style]['PAD'] for _ in range(sequence_len - len(x))]
y = y + [0 for _ in range(sequence_len - len(y))]
y_mask = y_mask + [0 for _ in range(sequence_len - len(y_mask))]
attn_mask = [1 if token != TOKEN_IDX[token_style]['PAD'] else 0 for token in x]
data_items.append([x, y, attn_mask, y_mask])
for i in range(len(tokens) - 1):
x.append(tokenizer.convert_tokens_to_ids(tokens[i]))
y.append(0)
y_mask.append(0)
if len(tokens) > 0:
x.append(tokenizer.convert_tokens_to_ids(tokens[-1]))
else:
x.append(TOKEN_IDX[token_style]['UNK'])
y.append(punctuation_dict[punc])
y_mask.append(1)
idx += 1
pbar.update(1) # Update progress bar for each line processed
x.append(TOKEN_IDX[token_style]['END_SEQ'])
y.append(0)
y_mask.append(1)
if len(x) < sequence_len:
x = x + [TOKEN_IDX[token_style]['PAD'] for _ in range(sequence_len - len(x))]
y = y + [0 for _ in range(sequence_len - len(y))]
y_mask = y_mask + [0 for _ in range(sequence_len - len(y_mask))]
attn_mask = [1 if token != TOKEN_IDX[token_style]['PAD'] else 0 for token in x]
data_items.append([x, y, attn_mask, y_mask])
pbar.update(1) # Ensure pbar update in outer loop too
return data_items


@@ -69,11 +73,16 @@ def __init__(self, files, tokenizer, sequence_len, token_style, is_train=False,
:param augment_rate: token augmentation rate when preparing data
:param is_train: if false do not apply augmentation
"""
print('*')
if isinstance(files, list):
print('***')
self.data = []
for file in files:
# for file in files:
for file in tqdm(files, desc="Loading files"):
print('**')
self.data += parse_data(file, tokenizer, sequence_len, token_style)
else:
print('---')
self.data = parse_data(files, tokenizer, sequence_len, token_style)
self.sequence_len = sequence_len
self.augment_rate = augment_rate
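As the parsing loop above shows (lines[idx].split('\t') and punctuation_dict[punc]), each input file is expected to contain one token per line, tab-separated from a label that must be a key of punctuation_dict. A tiny illustrative sample of the tsv_*.txt format (the Persian tokens are placeholders, not taken from the real dataset):

سلام	O
دوستان	COMMA
امروز	O
چطور	O
هستید	QUESTION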
11 changes: 6 additions & 5 deletions src/inference.py
@@ -7,17 +7,17 @@

parser = argparse.ArgumentParser(description='Punctuation restoration inference on text file')
parser.add_argument('--cuda', default=True, type=lambda x: (str(x).lower() == 'true'), help='use cuda if available')
parser.add_argument('--pretrained-model', default='xlm-roberta-large', type=str, help='pretrained language model')
parser.add_argument('--pretrained-model', default='HooshvareLab/bert-base-parsbert-uncased', type=str, help='pretrained language model')
parser.add_argument('--lstm-dim', default=-1, type=int,
help='hidden dimension in LSTM layer, if -1 is set equal to hidden dimension in language model')
parser.add_argument('--use-crf', default=False, type=lambda x: (str(x).lower() == 'true'),
help='whether to use CRF layer or not')
parser.add_argument('--language', default='en', type=str, help='language English (en) oe Bangla (bn)')
parser.add_argument('--in-file', default='data/test_en.txt', type=str, help='path to inference file')
parser.add_argument('--weight-path', default='xlm-roberta-large.pt', type=str, help='model weight path')
parser.add_argument('--in-file', default='/home/ubuntu/punc/punctuation-restoration/data/test_fa.txt', type=str, help='path to inference file')
parser.add_argument('--weight-path', default='/home/ubuntu/punc/punctuation-restoration/out/weights.pt', type=str, help='model weight path')
parser.add_argument('--sequence-length', default=256, type=int,
help='sequence length to use when preparing dataset (default 256)')
parser.add_argument('--out-file', default='data/test_en_out.txt', type=str, help='output file location')
parser.add_argument('--out-file', default='/home/ubuntu/punc/punctuation-restoration/data/test_fa_out.txt', type=str, help='output file location')

args = parser.parse_args()

@@ -51,7 +51,8 @@ def inference():
sequence_len = args.sequence_length
result = ""
decode_idx = 0
punctuation_map = {0: '', 1: ',', 2: '.', 3: '?'}

punctuation_map = {0: '', 1: ',', 2: '.', 3: '?', 4: '!', 5: ':'}
if args.language != 'en':
punctuation_map[2] = '।'

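The extended punctuation_map is what turns predicted class ids back into characters when the output text is rebuilt. Roughly, and only as a sketch (the exact reconstruction loop in inference.py is not shown in this diff):

punctuation_map = {0: '', 1: ',', 2: '.', 3: '?', 4: '!', 5: ':'}

def attach_punctuation(words, predicted_ids):
    # Append the predicted mark (possibly empty) after each word.
    return ' '.join(word + punctuation_map[label] for word, label in zip(words, predicted_ids))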
14 changes: 8 additions & 6 deletions src/test.py
@@ -11,13 +11,13 @@

parser = argparse.ArgumentParser(description='Punctuation restoration test')
parser.add_argument('--cuda', default=True, type=lambda x: (str(x).lower() == 'true'), help='use cuda if available')
parser.add_argument('--pretrained-model', default='roberta-large', type=str, help='pretrained language model')
parser.add_argument('--pretrained-model', default='HooshvareLab/bert-base-parsbert-uncased', type=str, help='pretrained language model')
parser.add_argument('--lstm-dim', default=-1, type=int,
help='hidden dimension in LSTM layer, if -1 is set equal to hidden dimension in language model')
parser.add_argument('--use-crf', default=False, type=lambda x: (str(x).lower() == 'true'),
help='whether to use CRF layer or not')
parser.add_argument('--data-path', default='data/test', type=str, help='path to test datasets')
parser.add_argument('--weight-path', default='out/weights.pt', type=str, help='model weight path')
parser.add_argument('--weight-path', default='out/5_weights.pt', type=str, help='model weight path')
parser.add_argument('--sequence-length', default=256, type=int,
help='sequence length to use when preparing dataset (default 256)')
parser.add_argument('--batch-size', default=8, type=int, help='batch size (default: 8)')
@@ -65,10 +65,10 @@ def test(data_loader):
num_iteration = 0
deep_punctuation.eval()
# +1 for overall result
tp = np.zeros(1+len(punctuation_dict), dtype=np.int)
fp = np.zeros(1+len(punctuation_dict), dtype=np.int)
fn = np.zeros(1+len(punctuation_dict), dtype=np.int)
cm = np.zeros((len(punctuation_dict), len(punctuation_dict)), dtype=np.int)
tp = np.zeros(1+len(punctuation_dict), dtype=np.int64)
fp = np.zeros(1+len(punctuation_dict), dtype=np.int64)
fn = np.zeros(1+len(punctuation_dict), dtype=np.int64)
cm = np.zeros((len(punctuation_dict), len(punctuation_dict)), dtype=np.int64)
correct = 0
total = 0
with torch.no_grad():
@@ -93,6 +93,8 @@
# we can ignore this because we know there won't be any punctuation in this position
# since we created this position due to padding or sub-word tokenization
continue
print('i: ', i)
print('y[i]: ', y[i])
cor = y[i]
prd = y_predict[i]
if cor == prd:
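The dtype change is needed because the np.int alias was deprecated in NumPy 1.20 and removed in 1.24; any explicit integer dtype works (the equivalent counters in train.py below use np.int32, which is equally valid). A one-line illustration:

import numpy as np
tp = np.zeros(1 + 6, dtype=np.int64)  # 6 == len(punctuation_dict) after this PR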
50 changes: 35 additions & 15 deletions src/train.py
@@ -13,7 +13,7 @@
import augmentation

torch.multiprocessing.set_sharing_strategy('file_system') # https://github.com/pytorch/pytorch/issues/11201

print(1)
args = parse_arguments()

# for reproducibility
@@ -22,6 +22,7 @@
torch.backends.cudnn.benchmark = False
np.random.seed(args.seed)

print(2)
# tokenizer
tokenizer = MODELS[args.pretrained_model][1].from_pretrained(args.pretrained_model)
augmentation.tokenizer = tokenizer
@@ -34,16 +35,27 @@
aug_type = args.augment_type

# Datasets
if args.language == 'english':
train_set = Dataset(os.path.join(args.data_path, 'en/train2012'), tokenizer=tokenizer, sequence_len=sequence_len,
if args.language == 'persian':
train_set = Dataset(os.path.join(args.data_path, 'tsv_train.txt'), tokenizer=tokenizer, sequence_len=sequence_len,
token_style=token_style, is_train=True, augment_rate=ar, augment_type=aug_type)
val_set = Dataset(os.path.join(args.data_path, 'en/dev2012'), tokenizer=tokenizer, sequence_len=sequence_len,
val_set = Dataset(os.path.join(args.data_path, 'tsv_dev.txt'), tokenizer=tokenizer, sequence_len=sequence_len,
token_style=token_style, is_train=False)
test_set_ref = Dataset(os.path.join(args.data_path, 'en/test2011'), tokenizer=tokenizer, sequence_len=sequence_len,
test_set_ref = Dataset(os.path.join(args.data_path, 'tsv_test.txt'), tokenizer=tokenizer, sequence_len=sequence_len,
token_style=token_style, is_train=False)
test_set_asr = Dataset(os.path.join(args.data_path, 'en/test2011asr'), tokenizer=tokenizer, sequence_len=sequence_len,

test_set = [val_set, test_set_ref]

elif args.language == 'english':
train_set = Dataset(os.path.join(args.data_path, 'tsv_train.txt'), tokenizer=tokenizer, sequence_len=sequence_len,
token_style=token_style, is_train=True, augment_rate=ar, augment_type=aug_type)
val_set = Dataset(os.path.join(args.data_path, 'tsv_dev.txt'), tokenizer=tokenizer, sequence_len=sequence_len,
token_style=token_style, is_train=False)
test_set_ref = Dataset(os.path.join(args.data_path, 'tsv_test.txt'), tokenizer=tokenizer, sequence_len=sequence_len,
token_style=token_style, is_train=False)
test_set_asr = Dataset(os.path.join(args.data_path, 'tsv_dev3.txt'), tokenizer=tokenizer, sequence_len=sequence_len,
token_style=token_style, is_train=False)
test_set = [val_set, test_set_ref, test_set_asr]

elif args.language == 'bangla':
train_set = Dataset(os.path.join(args.data_path, 'bn/train'), tokenizer=tokenizer, sequence_len=sequence_len,
token_style=token_style, is_train=True, augment_rate=ar, augment_type=aug_type)
@@ -56,6 +68,7 @@
test_set_asr = Dataset(os.path.join(args.data_path, 'bn/test_asr'), tokenizer=tokenizer, sequence_len=sequence_len,
token_style=token_style, is_train=False)
test_set = [val_set, test_set_news, test_set_ref, test_set_asr]

elif args.language == 'english-bangla':
train_set = Dataset([os.path.join(args.data_path, 'en/train2012'), os.path.join(args.data_path, 'bn/train_bn')],
tokenizer=tokenizer, sequence_len=sequence_len, token_style=token_style, is_train=True,
@@ -78,30 +91,34 @@

# Data Loaders
data_loader_params = {
'batch_size': args.batch_size,
'batch_size': 8, #
'shuffle': True,
'num_workers': 1
}
train_loader = torch.utils.data.DataLoader(train_set, **data_loader_params)
val_loader = torch.utils.data.DataLoader(val_set, **data_loader_params)
test_loaders = [torch.utils.data.DataLoader(x, **data_loader_params) for x in test_set]
print('Data Loaded')

# logs
os.makedirs(args.save_path, exist_ok=True)
model_save_path = os.path.join(args.save_path, 'weights.pt')
log_path = os.path.join(args.save_path, args.name + '_logs.txt')


# Model
device = torch.device('cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu')
if args.use_crf:
deep_punctuation = DeepPunctuationCRF(args.pretrained_model, freeze_bert=args.freeze_bert, lstm_dim=args.lstm_dim)
else:
deep_punctuation = DeepPunctuation(args.pretrained_model, freeze_bert=args.freeze_bert, lstm_dim=args.lstm_dim)
deep_punctuation.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(deep_punctuation.parameters(), lr=args.lr, weight_decay=args.decay)

weights = torch.tensor([0.2, 1.0, 1.5, 1.0, 1.0, 1.0], dtype=torch.float32)
criterion = nn.CrossEntropyLoss(weight=weights) # ignore_index=0

optimizer = torch.optim.Adam(deep_punctuation.parameters(), lr=args.lr, weight_decay=args.decay)
print(device)
print('Crt and Opt loaded!')

def validate(data_loader):
"""
@@ -142,10 +159,10 @@ def test(data_loader):
num_iteration = 0
deep_punctuation.eval()
# +1 for overall result
tp = np.zeros(1+len(punctuation_dict), dtype=np.int)
fp = np.zeros(1+len(punctuation_dict), dtype=np.int)
fn = np.zeros(1+len(punctuation_dict), dtype=np.int)
cm = np.zeros((len(punctuation_dict), len(punctuation_dict)), dtype=np.int)
tp = np.zeros(1+len(punctuation_dict), dtype=np.int32)
fp = np.zeros(1+len(punctuation_dict), dtype=np.int32)
fn = np.zeros(1+len(punctuation_dict), dtype=np.int32)
cm = np.zeros((len(punctuation_dict), len(punctuation_dict)), dtype=np.int32)
correct = 0
total = 0
with torch.no_grad():
@@ -190,6 +207,7 @@ def test(data_loader):


def train():
print('Train Started!')
with open(log_path, 'a') as f:
f.write(str(args)+'\n')
best_val_acc = 0
@@ -242,7 +260,9 @@ def train():
print(log)
if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save(deep_punctuation.state_dict(), model_save_path)
torch.save(deep_punctuation.state_dict(), os.path.join(args.save_path, (str(epoch) + '_weights.pt')))

torch.save(deep_punctuation.state_dict(), model_save_path)

print('Best validation Acc:', best_val_acc)
deep_punctuation.load_state_dict(torch.load(model_save_path))
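One caveat about the weighted loss introduced above: nn.CrossEntropyLoss expects its weight tensor on the same device as the logits, so when training on GPU the tensor should be moved there first. A minimal sketch of the safer pattern (a suggested adjustment, not what the diff currently does):

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One weight per class in punctuation_dict; index 0 ('O') is down-weighted.
weights = torch.tensor([0.2, 1.0, 1.5, 1.0, 1.0, 1.0], dtype=torch.float32).to(device)
criterion = nn.CrossEntropyLoss(weight=weights)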