-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added code to replicate the hypo2hyper, hypo2path and hypo2path (rev)…
… experiments.
- Loading branch information
Showing
12 changed files
with
1,213 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2020 Juan Diego Rodriguez | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# keras_seq2seq | ||
Word-level seq2seq model in keras | ||
|
||
## Download the synset embeddings | ||
|
||
These are located at https://github.com/scarletcho/hypernym-path-generation/releases/tag/v1.0 Move them to the `embs/` directory to use them as in the examples below. | ||
|
||
## Hyperparameters | ||
|
||
The hyperparameters used in our paper are the default values for arguments in `wordseq2seq.py`. We used the following number of epochs for hypo2hyper rev. Nouns: 45. Verbs: 33. Instance nouns: 60. | ||
|
||
## Train the model | ||
|
||
Example use: train for 70 epochs with training source and target pairs in files `src-train.txt` and `tgt-train.txt` (and similarly for validation). | ||
|
||
``` | ||
python3 wordseq2seq.py \ | ||
--train_src hyp_data2_nodehyp/src-train.txt \ | ||
--train_tgt hyp_data2_nodehyp/tgt-train.txt \ | ||
--valid_src hyp_data2_nodehyp/src-val.txt \ | ||
--valid_tgt hyp_data2_nodehyp/tgt-val.txt \ | ||
--test_src hyp_data2_nodehyp/src-test.txt \ | ||
--test_tgt hyp_data2_nodehyp/tgt-test.txt \ | ||
--emb_file_enc embs/ft-embs-all-lower.vec \ | ||
--emb_file_dec embs/ft-embs-all-lower.vec \ | ||
--epochs 100 \ | ||
--attention \ | ||
``` | ||
|
||
Comments on the arguments: | ||
|
||
- If you *don't* want to use Luong attention, remove the `--attention` flag. | ||
- To only use the first N lines of training data: `--num_samples_train N`. To only use first N lines of validation data: `--num_samples_val N`. | ||
- If you *don't* wish to use pretrained embeddings for encoder and/or decoder, use: `--emb_file_enc None` and `--emb_file_dec None`. | ||
- Word embeddings will be frozen by default. To train them, use flag `--trainable_src_emb` and/or `--trainable_tgt_emb` | ||
- Use `--save_checkpoint_epochs N` to save a checkpoint every N epochs. | ||
- To see other command line arguments and default values, type: | ||
```python wordseq2seq.py -h``` | ||
|
||
The trained model will be saved in directory `word_models` (if it already exists it will be overwritten). | ||
|
||
## Generate translations | ||
|
||
This is an example use that translates from `src-val.txt` and saves results to `pred.txt`: | ||
|
||
``` | ||
python3 generate.py \ | ||
word_models/word_encoding.json \ | ||
word_models/word_decoding.json \ | ||
word_models/weights.60.h5 \ | ||
hyp_data2/src-val.txt \ | ||
pred.txt | ||
``` | ||
|
||
The first two arguments are the encoding and decoding json files needed to load the data. This is followed by one of the saved model .h5 files (each file contains the number of epochs in its name) and the file containing the source sentences to translate. The last argument is the name of the file to write results to. | ||
|
||
## Evaluate translations | ||
|
||
To evaluate on the validation set, for instance nouns, use the following (60 indicates the number of epochs, and 1 at the end indicates the predicted paths should be reversed): | ||
|
||
``` | ||
python3 combine.py src-valid.txt tgt-valid.txt pred-val-60.txt 60 instnouns val 1 | ||
. evaluate.sh y_results_60e_instnouns_val results_name_here | ||
``` | ||
|
||
The first argument consists of source hyponyms; the second argument consists of target gold-truth WordNet paths and the third argument is the file with predicted paths. | ||
The script will create a `.out.summary.txt` file with various scores, including H@1 and Wu&P. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import sys | ||
|
||
if __name__ == '__main__': | ||
|
||
src_file = sys.argv[1] | ||
tgt_file = sys.argv[2] | ||
pred_file = sys.argv[3] | ||
epochs = sys.argv[4] | ||
category = sys.argv[5] # verbs, nouns or instnouns | ||
split = sys.argv[6] # val, or test | ||
|
||
#optional (if used 'reversed' version, hypernym is first, not last) | ||
try: # use 1 here if want reversed. | ||
reverse = sys.argv[7] | ||
reverse = bool(reverse) | ||
except: | ||
reverse=False | ||
|
||
if category in {'verbs', 'nouns'}: | ||
hyp_name = '_hypernym' | ||
elif category == 'instnouns': | ||
hyp_name = '_instance_hypernym' | ||
else: | ||
raise ValueError("Must be 'nouns', 'verbs', or 'instnouns' ") | ||
|
||
with open(pred_file, 'r') as fd: | ||
pred = fd.readlines() | ||
|
||
with open(src_file, 'r') as fd: | ||
srcval = fd.readlines() | ||
|
||
with open(tgt_file, 'r') as fd: | ||
tgtval = fd.readlines() | ||
|
||
pred = [i.strip() for i in pred] | ||
|
||
if reverse: | ||
pred = [i.split(' ')[0] for i in pred] | ||
else: | ||
pred = [i.split(' ')[-1] for i in pred] | ||
|
||
srcval = [i.strip() for i in srcval] | ||
tgtval = [i.strip() for i in tgtval] | ||
|
||
|
||
with open('y_results_'+epochs+'e_'+category+'_'+ split +'.txt','w') as fd: | ||
for ii,i in enumerate(pred): | ||
fd.write(srcval[ii]+'\t'+hyp_name+'\t'+tgtval[ii]+'\t'+pred[ii]+'\n' ) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
import pandas as pd | ||
import re | ||
import string | ||
import numpy as np | ||
import json | ||
|
||
def get_emb_matrix(target_i2w, w2v, EMB_DIM, unk_size=0.25, oov_filename=None): | ||
emb_mat = np.zeros((len(target_i2w), EMB_DIM)) | ||
notfound = [] | ||
|
||
for i in target_i2w.keys(): | ||
if i!=0: | ||
if target_i2w[i] in w2v: | ||
emb_mat[i] = w2v[target_i2w[i]] | ||
else: | ||
notfound.append(target_i2w[i]) | ||
#print(target_i2w[i], "not there!") | ||
emb_mat[i] = np.random.uniform(-unk_size, unk_size, EMB_DIM) | ||
|
||
if oov_filename is not None: | ||
with open(oov_filename, 'w') as fd: | ||
for i in notfound: | ||
fd.write(i+'\n') | ||
|
||
return emb_mat, notfound | ||
|
||
|
||
def load_embs(embfile): | ||
with open(embfile,'r',encoding='utf8') as fd: | ||
t = fd.readlines() | ||
t = [i.strip() for i in t] | ||
t = [i.split(' ') for i in t] | ||
words = [i[0] for i in t] | ||
vecs = [i[1:] for i in t] | ||
vecs = [np.array([float(i) for i in vec]) for vec in vecs] | ||
D = dict(zip(words, vecs)) | ||
return D | ||
|
||
def write_dict_to_json(filename, D): | ||
print("Writing dictionary to "+filename) | ||
with open(filename, 'w', encoding='utf8') as f: | ||
f.write(json.dumps(D, sort_keys=True, indent=4)) | ||
|
||
def load_dict_from_json(filename): | ||
with open(filename,'r',encoding='utf8') as f: | ||
data = json.loads(f.read()) | ||
return data | ||
|
||
|
||
def load_data(filename, numlines=None, optional_processing=False):#50000): | ||
""" filename: either str (name of file; the file must be a list of pairs | ||
in the form src \t tgt) or tuple (src, tgt). | ||
If numlines is None, load all the data. | ||
If optional_processing is True: | ||
- lowercase everything | ||
- remove punctuation | ||
""" | ||
#filename = 'data-text/fra.txt' | ||
if isinstance(filename,str): | ||
lines = pd.read_table(filename, names=['src', 'tgt']) | ||
#print("Number of samples used from "+filename+": ",str(len(lines)) ) | ||
elif isinstance(filename,tuple): | ||
s,t = load_two_files(*filename) # first is src, second is tgt. | ||
lines = pd.DataFrame({'src':s, 'tgt':t}) | ||
#print("Number of samples used from "+", ".join(filename)+": ",str(len(lines)) ) | ||
else: | ||
raise ValueError("Must be either name of the file or a tuple of filenames (src, tgt).") | ||
|
||
if numlines is not None: | ||
lines = lines[:numlines] | ||
|
||
|
||
if optional_processing: | ||
lines.src=lines.src.apply(lambda x: x.lower()) | ||
lines.tgt=lines.tgt.apply(lambda x: x.lower()) | ||
#lines.src=lines.src.apply(lambda x: re.sub("'", '', x)).apply(lambda x: re.sub(",", ' COMMA', x)) | ||
#lines.tgt=lines.tgt.apply(lambda x: re.sub("'", '', x)).apply(lambda x: re.sub(",", ' COMMA', x)) | ||
exclude = set(string.punctuation) | ||
lines.src=lines.src.apply(lambda x: ''.join(ch for ch in x if ch not in exclude)) | ||
lines.tgt=lines.tgt.apply(lambda x: ''.join(ch for ch in x if ch not in exclude)) | ||
|
||
lines.tgt = lines.tgt.apply(lambda x : 'START_ '+ x + ' _END') | ||
|
||
return lines | ||
|
||
|
||
def load_two_files(src_file, tgt_file): | ||
with open(src_file,'r', encoding='utf8') as fd: | ||
src = fd.readlines() | ||
src = [i.strip() for i in src] | ||
|
||
with open(tgt_file,'r', encoding='utf8') as fd: | ||
tgt = fd.readlines() | ||
tgt = [i.strip() for i in tgt] | ||
|
||
return src, tgt | ||
|
||
def get_max_sentence_lengths(lines): | ||
max_target_sentence = max([len(i.split()) for i in lines.tgt]) | ||
max_source_sentence = max([len(i.split()) for i in lines.src]) | ||
print("Max sequence length for inputs: ", max_source_sentence) | ||
print("Max sequence length for outputs: ", max_target_sentence) | ||
return max_source_sentence, max_target_sentence | ||
|
||
|
||
def prepare_data_shared(lines): | ||
all_src_words=set() | ||
for src in lines.src: | ||
for word in src.split(): | ||
if word not in all_src_words: | ||
all_src_words.add(word) | ||
|
||
all_tgt_words=set() | ||
for tgt in lines.tgt: | ||
for word in tgt.split(): | ||
if word not in all_tgt_words: | ||
all_tgt_words.add(word) | ||
|
||
all_words = all_src_words | all_tgt_words | ||
all_words = all_words - {'START_'} | ||
#all_tgt_words = all_tgt_words - {'START_'} | ||
|
||
#input_words = sorted(list(all_src_words)) | ||
target_words = ['START_']+sorted(list(all_words)) # want 'START_' to be the first | ||
|
||
#NOTE: want the first entry (0th) to correspond to the start symbol | ||
#input_w2i = dict( | ||
# [(word, i) for i, word in enumerate(input_words)]) | ||
target_w2i = dict( | ||
[(word, i) for i, word in enumerate(target_words)]) | ||
|
||
input_w2i = target_w2i | ||
|
||
print("Target vocab size: ", len(target_w2i)) | ||
print("Source vocab size: ", len(input_w2i)) | ||
return input_w2i, target_w2i | ||
|
||
|
||
def prepare_data(lines): | ||
all_src_words=set() | ||
for src in lines.src: | ||
for word in src.split(): | ||
if word not in all_src_words: | ||
all_src_words.add(word) | ||
|
||
all_tgt_words=set() | ||
for tgt in lines.tgt: | ||
for word in tgt.split(): | ||
if word not in all_tgt_words: | ||
all_tgt_words.add(word) | ||
all_tgt_words = all_tgt_words - {'START_'} | ||
|
||
input_words = sorted(list(all_src_words)) | ||
target_words = ['START_']+sorted(list(all_tgt_words)) # want 'START_' to be the first | ||
|
||
#NOTE: want the first entry (0th) to correspond to the start symbol | ||
input_w2i = dict( | ||
[(word, i) for i, word in enumerate(input_words)]) | ||
target_w2i = dict( | ||
[(word, i) for i, word in enumerate(target_words)]) | ||
|
||
print("Source vocab size: ", len(input_w2i)) | ||
print("Target vocab size: ", len(target_w2i)) | ||
|
||
return input_w2i, target_w2i | ||
|
||
|
||
def encode_texts(input_texts, input_w2i, max_source_sentence): | ||
|
||
##max_source_sentence = max([len(i.split()) for i in input_texts]) | ||
|
||
encoder_input = np.zeros( | ||
(len(input_texts), max_source_sentence), | ||
dtype='float32') | ||
|
||
for i, input_text in enumerate(input_texts): | ||
for t, word in enumerate(input_text.split()): | ||
encoder_input[i, t] = input_w2i[word] #TODO get keyerror now.. what about OOV? | ||
|
||
return encoder_input | ||
|
||
|
||
def get_encoder_and_decoder_arrays(input_w2i, target_w2i, max_source_sentence, max_target_sentence, lines): | ||
|
||
source_i2w = dict((i, word) for word,i in input_w2i.items()) | ||
target_i2w = dict((i, word) for word,i in target_w2i.items()) | ||
|
||
num_decoder_tokens = len(target_w2i) | ||
|
||
encoder_input_data = np.zeros( | ||
(len(lines.src), max_source_sentence), | ||
dtype='float32') | ||
|
||
decoder_input_data = np.zeros( | ||
(len(lines.tgt), max_target_sentence), | ||
dtype='float32') | ||
|
||
decoder_target_data = np.zeros( | ||
(len(lines.tgt), max_target_sentence, num_decoder_tokens), | ||
dtype='float32') | ||
|
||
for i, (input_text, target_text) in enumerate(zip(lines.src, lines.tgt)): | ||
for t, word in enumerate(input_text.split()): | ||
encoder_input_data[i, t] = input_w2i[word] | ||
|
||
for t, word in enumerate(target_text.split()): | ||
decoder_input_data[i, t] = target_w2i[word] | ||
if t > 0: | ||
# Teacher forcing. | ||
# Decoder_target_data is ahead of decoder_input_data by one timestep | ||
# and will not include the start character. | ||
decoder_target_data[i, t - 1, target_w2i[word]] = 1. # probability=1 on the known word. | ||
|
||
# sanity check | ||
# def pad(x,padlen,padval): | ||
# if len(x)<padlen: | ||
# return x+[padval]*(padlen-len(x)) | ||
# else: | ||
# return x | ||
# for ii in range(len(lines.src)): | ||
# S = [source_i2w[x] for x in encoder_input_data[ii]] | ||
# vals = lines.src[ii].split() | ||
# P = pad(vals, max_source_sentence, source_i2w[0]) | ||
# if S!=P: | ||
# print('bad encoder val: ', ii) | ||
# for ii in range(len(lines.tgt)): | ||
# S = [target_i2w[x] for x in decoder_input_data[ii]] | ||
# vals = lines.tgt[ii].split() | ||
# P = pad(vals, max_target_sentence, target_i2w[0]) | ||
# if S!=P: | ||
# print('bad decoder val: ', ii) | ||
|
||
return encoder_input_data, decoder_input_data, decoder_target_data | ||
|
||
|
||
#if __name__ == "__main__": | ||
# filename = 'data-test/fra.txt' | ||
# numlines = 40000#801 | ||
# lines = load_data(filename,numlines=numlines) | ||
# input_token_index, target_token_index, max_source_sentence, max_target_sentence = prepare_data(lines) | ||
|
||
# encoder_input_data, decoder_input_data, decoder_target_data = get_encoder_and_decoder_arrays(input_token_index, target_token_index, max_source_sentence, max_target_sentence, lines) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Put ft_embs_all_lower.vec here. |
Oops, something went wrong.