from __future__ import print_function
import numpy as np
import os
import random
from parse_json import data_from_json, read_write_dataset, read_write_dev_dataset
from vocab_util import create_vocabulary, initialize_vocabulary, process_glove, data_to_token_ids, create_vocab2charix_dict
import json
import argparse
# Example usage:
#   ~/anaconda3/bin/python pre_process.py '/home/orlandom/datasets/train-v1.1.json' '/home/orlandom/datasets/dev-v1.1.json' '/home/orlandom/Downloads/glove.6B/glove.6B.300d.txt' 'data'
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Pre-process the SQuAD training and dev files')
    parser.add_argument('train_path', help='JSON training file')
    parser.add_argument('dev_path', help='JSON dev file')
    parser.add_argument('glove_file_path', help='GloVe embeddings file')
    parser.add_argument('data_prefix', help='folder where the data files are saved')
    args = parser.parse_args()
    data_prefix = args.data_prefix
    train_path = args.train_path
    dev_path = args.dev_path
    glove_file_path = args.glove_file_path
    random.seed(42)
    np.random.seed(42)
    # data_prefix = 'data/'
    # train_path = '/home/orlandom/datasets/train-v1.1.json'
    # glove_file_path = '/home/orlandom/Downloads/glove.6B/glove.6B.300d.txt'
    save_path = data_prefix + '/glove.trimmed.300d.npz'
    vocab_path = data_prefix + '/vocab'
    train_context_path = data_prefix + '/train.context'
    train_question_path = data_prefix + '/train.question'
    train_context_ids_path = train_context_path + ".ids"
    train_question_ids_path = train_question_path + ".ids"
    dev_context_path = data_prefix + '/dev.context'
    dev_question_path = data_prefix + '/dev.question'
    dev_context_ids_path = dev_context_path + ".ids"
    dev_question_ids_path = dev_question_path + ".ids"
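    # Everything the script produces is rooted at data_prefix: the GloVe matrix
    # trimmed to the vocabulary (.npz), a shared vocab file, per-split
    # .context/.question text files with matching .ids token-id files, and the
    # character-level JSON mappings built further below.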
    ##################################################################################
    # Do vocabulary and word pre-processing
    ##################################################################################
    train_data = data_from_json(train_path)
    dev_data = data_from_json(dev_path)
    if not os.path.exists(data_prefix):
        os.makedirs(data_prefix)
    train_num_questions, train_num_answers = read_write_dataset(train_data, 'train', data_prefix)
    dev_num_questions, dev_num_answers = read_write_dev_dataset(dev_data, 'dev', data_prefix)
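    # read_write_dataset / read_write_dev_dataset (defined in parse_json) flatten
    # the SQuAD JSON into the per-split .context and .question text files consumed
    # below, and return question/answer counts (currently unused here).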
    create_vocabulary(vocab_path, [train_context_path, train_question_path])
    vocab, rev_vocab = initialize_vocabulary(vocab_path)
    process_glove(glove_file_path, rev_vocab, save_path)
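    # process_glove (see vocab_util) presumably trims the full GloVe file down to
    # the rows for words in rev_vocab and saves the reduced embedding matrix to
    # save_path; the exact .npz layout depends on that implementation.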
    data_to_token_ids(train_context_path, train_context_ids_path, vocab_path)
    data_to_token_ids(train_question_path, train_question_ids_path, vocab_path)
    data_to_token_ids(dev_context_path, dev_context_ids_path, vocab_path)
    data_to_token_ids(dev_question_path, dev_question_ids_path, vocab_path)
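    # Each .ids file mirrors its source file line for line, with every token
    # replaced by its integer index in the shared vocabulary.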
    ##################################################################################
    # Do character pre-processing
    ##################################################################################
    char2ix = {'<pad>': '0', '<unk>': '1'}
    ix2char = {'0': '<pad>', '1': '<unk>'}
    char2ix_file = data_prefix + '/char2ix.json'
    ix2char_file = data_prefix + '/ix2char.json'
    vocab2charix_file = data_prefix + '/vocab2charix.json'
    with open(vocab_path) as f:
        i = len(char2ix)
        for line in f:
            line = line.strip()
            if line in ['<pad>', '<sos>', '<unk>']:
                continue
            for char in line:
                if char not in char2ix:
                    char2ix[char] = str(i)
                    ix2char[str(i)] = char
                    i += 1
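    # char2ix now maps every character seen in the vocabulary to an index, with
    # '0' and '1' reserved for <pad> and <unk>. Indices are stored as strings,
    # which keeps char2ix and ix2char symmetric after the JSON round-trip below
    # (JSON object keys are always strings).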
    with open(char2ix_file, 'w') as f:
        json.dump(char2ix, f)
    with open(ix2char_file, 'w') as f:
        json.dump(ix2char, f)
    create_vocab2charix_dict(vocab_path, vocab2charix_file, char2ix)
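    # Sketch of how the artifacts produced above might be loaded downstream. The
    # 'glove' key is an assumption about how process_glove names the array inside
    # the .npz file; adjust it to match vocab_util.
    #   import json
    #   import numpy as np
    #   embeddings = np.load('data/glove.trimmed.300d.npz')['glove']
    #   with open('data/char2ix.json') as f:
    #       char2ix = json.load(f)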