diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..a47fb1d22 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.idea +__pycache__ \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 000000000..5c98b4288 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,2 @@ +# Default ignored files +/workspace.xml \ No newline at end of file diff --git a/.idea/ggnn.pytorch.iml b/.idea/ggnn.pytorch.iml new file mode 100644 index 000000000..7c9d48f0f --- /dev/null +++ b/.idea/ggnn.pytorch.iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 000000000..105ce2da2 --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/markdown-navigator b/.idea/markdown-navigator new file mode 100644 index 000000000..7f7ebda5a --- /dev/null +++ b/.idea/markdown-navigator @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/markdown-navigator.xml b/.idea/markdown-navigator.xml new file mode 100644 index 000000000..fc6322d27 --- /dev/null +++ b/.idea/markdown-navigator.xml @@ -0,0 +1,86 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 000000000..65531ca99 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 000000000..ce7d366fb --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 000000000..94a25f7f4 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/__pycache__/model_eager.cpython-36.pyc b/__pycache__/model_eager.cpython-36.pyc new file mode 100644 index 000000000..e6a2a1fc2 Binary files /dev/null and b/__pycache__/model_eager.cpython-36.pyc differ diff --git a/babi_data/extra_seq_tasks/generate_10_fold_data.sh b/babi_data/extra_seq_tasks/generate_10_fold_data.sh index 5aef612eb..01dff12e3 100755 --- a/babi_data/extra_seq_tasks/generate_10_fold_data.sh +++ b/babi_data/extra_seq_tasks/generate_10_fold_data.sh @@ -1,31 +1,31 @@ #!/bin/bash for fold in {1..10}; do - echo - echo Generating data for fold $fold - echo + echo + echo Generating data for fold $fold + echo - mkdir -p fold_$fold/noisy_data - mkdir -p fold_$fold/noisy_data/train - mkdir -p fold_$fold/noisy_data/test + mkdir -p fold_$fold/noisy_data + mkdir -p fold_$fold/noisy_data/train + mkdir -p fold_$fold/noisy_data/test - for i in 4 5; do - python generate_data.py $i 5 1000 4 > fold_$fold/noisy_data/train/${i}.txt - python generate_data.py $i 5 1000 4 > fold_$fold/noisy_data/test/${i}.txt + for i in 4 5; do + python generate_data.py $i 5 1000 4 >fold_$fold/noisy_data/train/${i}.txt + python generate_data.py $i 5 1000 4 >fold_$fold/noisy_data/test/${i}.txt - python preprocess.py $i fold_$fold/noisy_data fold_$fold/noisy_parsed - done + python preprocess.py $i fold_$fold/noisy_data fold_$fold/noisy_parsed + done - mkdir -p fold_$fold/noisy_rnn - mkdir -p fold_$fold/noisy_rnn/train - mkdir -p fold_$fold/noisy_rnn/test + mkdir -p fold_$fold/noisy_rnn + mkdir -p fold_$fold/noisy_rnn/train + mkdir -p fold_$fold/noisy_rnn/test - for i in 4 5; do - python 
rnn_preprocess.py fold_$fold/noisy_parsed/train/${i}_graphs.txt fold_$fold/noisy_rnn/train/${i}_tokens.txt --mode graph --nval 50
-	python rnn_preprocess.py fold_$fold/noisy_rnn/train/${i}_tokens.txt fold_$fold/noisy_rnn/train/${i}_rnn.txt --mode rnn
-	python rnn_preprocess.py fold_$fold/noisy_rnn/train/${i}_tokens.txt.val fold_$fold/noisy_rnn/train/${i}_rnn.txt.val --mode rnn --dict fold_$fold/noisy_rnn/train/${i}_rnn.txt.dict
+    for i in 4 5; do
+        python rnn_preprocess.py fold_$fold/noisy_parsed/train/${i}_graphs.txt fold_$fold/noisy_rnn/train/${i}_tokens.txt --mode graph --nval 50
+        python rnn_preprocess.py fold_$fold/noisy_rnn/train/${i}_tokens.txt fold_$fold/noisy_rnn/train/${i}_rnn.txt --mode rnn
+        python rnn_preprocess.py fold_$fold/noisy_rnn/train/${i}_tokens.txt.val fold_$fold/noisy_rnn/train/${i}_rnn.txt.val --mode rnn --dict fold_$fold/noisy_rnn/train/${i}_rnn.txt.dict
 
-	python rnn_preprocess.py fold_$fold/noisy_parsed/test/${i}_graphs.txt fold_$fold/noisy_rnn/test/${i}_tokens.txt --mode graph
-	python rnn_preprocess.py fold_$fold/noisy_rnn/test/${i}_tokens.txt fold_$fold/noisy_rnn/test/${i}_rnn.txt --dict fold_$fold/noisy_rnn/train/${i}_rnn.txt.dict --mode rnn
-	done
+        python rnn_preprocess.py fold_$fold/noisy_parsed/test/${i}_graphs.txt fold_$fold/noisy_rnn/test/${i}_tokens.txt --mode graph
+        python rnn_preprocess.py fold_$fold/noisy_rnn/test/${i}_tokens.txt fold_$fold/noisy_rnn/test/${i}_rnn.txt --dict fold_$fold/noisy_rnn/train/${i}_rnn.txt.dict --mode rnn
+    done
 done
diff --git a/babi_data/extra_seq_tasks/generate_data.py b/babi_data/extra_seq_tasks/generate_data.py
index 1af585382..d8fb902fe 100755
--- a/babi_data/extra_seq_tasks/generate_data.py
+++ b/babi_data/extra_seq_tasks/generate_data.py
@@ -6,6 +6,7 @@
 import networkx as nx
 import networkx.algorithms as nxalg
 
+
 class Person(object):
     def __init__(self, name):
         self.name = name
@@ -16,11 +17,13 @@ def __str__(self):
 
     def gets(self, amount, display=True):
         if display:
-            print "%s gets %s" % (self.name, amount)
+            print("%s gets %s" % (self.name, amount))
         self.purse.append(amount)
 
     def gives_everything(self, other_person):
-        print "%s gives-everything-to %s" % (self.name, other_person.name)
+        print("%s gives-everything-to %s" % (self.name, other_person.name))
         for amount in self.purse:
             other_person.gets(amount, display=False)
         self.purse = []
@@ -30,8 +33,9 @@ def loses_one(self):
         amount_to_lose = random.choice(self.purse)
         self.purse.remove(amount_to_lose)
-        print "%s loses %s" % (self.name, amount_to_lose)
-
+        print("%s loses %s" % (self.name, amount_to_lose))
+
 
 def make_change(person, denominations):
     denominations = sorted(denominations, reverse=True)
@@ -74,10 +78,11 @@ def make_money_story(options):
 
 def generate_change_data(options):
     people = make_money_story(options)
-    
+
     change_giver = random.choice(filter(lambda p: len(p.purse) > 0, people))
     change = make_change(change_giver, options["denominations"]) + [""]
-    print "eval make-change %s\t%s" % (change_giver.name, ",".join([str(c) for c in change]))
+    print("eval make-change %s\t%s" % (change_giver.name, ",".join([str(c) for c in change])))
 
 
 def generate_purse_data(options):
@@ -85,18 +90,20 @@ def generate_purse_data(options):
     change_giver = random.choice(filter(lambda p: len(p.purse) > 0, people))
     change = change_giver.purse + [""]
-    print "eval coins %s\t%s" % (change_giver.name, ",".join([str(c) for c in change]))
-
+    print("eval coins %s\t%s" % (change_giver.name, ",".join([str(c) for c in change])))
+
 
 def generate_sorted_purse_data(options):
     people = make_money_story(options)
 
     change_giver = random.choice(filter(lambda p: len(p.purse) > 0, people))
     change = sorted(change_giver.purse, reverse=True) + [""]
-    print "eval coins %s\t%s" % (change_giver.name, ",".join([str(c) for c in change]))
+    print("eval coins %s\t%s" % (change_giver.name, ",".join([str(c) for c in change])))
 
 
-def generate_shortest_path_data(options):
+def generate_shortest_path_data(options):
     while True:
         num_nodes = options["num_entities"]
         g = nx.random_graphs.connected_watts_strogatz_graph(num_nodes, 3, .5)
@@ -113,14 +120,16 @@
             break
 
     for edge in g.edges():
-        print "%s connected-to %s" % (edge[0], edge[1])
-        print "%s connected-to %s" % (edge[1], edge[0])
+        print("%s connected-to %s" % (edge[0], edge[1]))
+        print("%s connected-to %s" % (edge[1], edge[0]))
 
-    print "eval shortest-path %s %s\t%s" % (source, target, ",".join([str(v) for v in path]))
-
+    print("eval shortest-path %s %s\t%s" % (source, target, ",".join([str(v) for v in path])))
+
 
-def generate_eulerian_circuit_data(options):
+def generate_eulerian_circuit_data(options):
     while True:
         num_nodes = options["num_entities"]
         g = nx.random_regular_graph(2, num_nodes)
@@ -135,14 +144,17 @@
             break
 
     for edge in g.edges():
-        print "%s connected-to %s" % (edge[0], edge[1])
-        print "%s connected-to %s" % (edge[1], edge[0])
+        print("%s connected-to %s" % (edge[0], edge[1]))
+        print("%s connected-to %s" % (edge[1], edge[0]))
 
     first_edge = path[0]
     node_list = [str(edge[0]) for edge in path]
-    print "eval eulerian-circuit %s %s\t%s" % (first_edge[0], first_edge[1],
-                                               ",".join(node_list))
+    print("eval eulerian-circuit %s %s\t%s" % (first_edge[0], first_edge[1],
+                                               ",".join(node_list)))
 
 
 ##################### noisy data #######################
@@ -152,17 +164,19 @@
     random.shuffle(idx)
     return idx
 
+
 def _relabel_nodes_in_edges(edges, idx):
     """
     edges is a list of tuples
     """
     return [(idx[e[0]], idx[e[1]]) for e in edges]
 
+
 def _relabel_nodes_in_path(path, idx):
     return [idx[n] for n in path]
 
 
-def generate_noisy_shortest_path_data(options):
+def generate_noisy_shortest_path_data(options):
     while True:
         num_nodes = options["num_entities"]
         min_path_len = options["min_path_len"]
@@ -177,7 +191,7 @@
 
         path = paths[0]
 
-        if len(path) < min_path_len: continue    # reject paths that's too short
+        if len(path) < min_path_len: continue  # reject paths that are too short
 
         break
 
@@ -198,11 +212,14 @@
     new_path = _relabel_nodes_in_path(path, idx)
 
     for edge in new_edges:
-        print "%s connected-to %s" % (edge[0], edge[1])
-        print "%s connected-to %s" % (edge[1], edge[0])
+        print("%s connected-to %s" % (edge[0], edge[1]))
+        print("%s connected-to %s" % (edge[1], edge[0]))
 
-    print "eval shortest-path %s %s\t%s" % (idx[source], idx[target], ",".join([str(v) for v in new_path]))
-
+    print("eval shortest-path %s %s\t%s" % (idx[source], idx[target], ",".join([str(v) for v in new_path])))
+
 
 def generate_noisy_eulerian_circuit_data(options):
     """
@@ -236,14 +253,18 @@
     new_path = _relabel_nodes_in_edges(path, idx)
 
     for edge in new_edges:
-        print "%s connected-to %s" % (edge[0], edge[1])
-        print "%s connected-to %s" % (edge[1], edge[0])
+        print("%s connected-to %s" % (edge[0], edge[1]))
+        print("%s connected-to %s" % (edge[1], edge[0]))
 
     first_edge = new_path[0]
     node_list = [str(edge[0]) for edge in new_path]
-    print "eval eulerian-circuit %s %s\t%s" % (first_edge[0], first_edge[1],
-                                               ",".join(node_list))
+    print("eval eulerian-circuit %s %s\t%s" % (first_edge[0], first_edge[1],
+                                               ",".join(node_list)))
+
 
 def main(task, options):
     if task == 1:
@@ -264,10 +285,10 @@ def main(task, options):
         generate_noisy_eulerian_circuit_data(options)
 
 
-
 if __name__ == "__main__":
     if len(sys.argv) < 4:
-        print 'python generate_data.py <task> <num_entities> <num_examples> [<num_confusing>]'
+        print('python generate_data.py <task> <num_entities> <num_examples> [<num_confusing>]')
     else:
         task = int(sys.argv[1])
         num_entities = int(sys.argv[2])
@@ -276,17 +297,17 @@ def main(task, options):
 
         if task <= 3:
             options = {
-                "num_entities" : num_entities,
-                "num_timesteps" : 20,
-                "denominations" : [1, 5, 10, 25]
-            }
+                "num_entities": num_entities,
+                "num_timesteps": 20,
+                "denominations": [1, 5, 10, 25]
+            }
         elif task >= 4:
             options = {
-                "num_entities" : num_entities,
-                "num_confusing" : num_confusing,
-                "min_path_len" : 3
-            }
+                "num_entities": num_entities,
+                "num_confusing": num_confusing,
+                "min_path_len": 3
+            }
 
         for i in xrange(num_examples):
             main(task, options)
diff --git a/babi_data/extra_seq_tasks/preprocess.py b/babi_data/extra_seq_tasks/preprocess.py
index 9702be60f..ebf78cfa0 100755
--- a/babi_data/extra_seq_tasks/preprocess.py
+++ b/babi_data/extra_seq_tasks/preprocess.py
@@ -7,6 +7,7 @@
 import os
 import sys
 
+
 def parse_dataset(file_name, edge_type_dict={}, question_type_dict={}):
     dataset = []
     with open(file_name, 'r') as f:
@@ -25,8 +26,8 @@
                     else:
                         question_type = question_type_dict[qtype]
 
-                    args = [int(a)+1 for a in args]
-                    targets = [int(t)+1 for t in targets.split(',')]
+                    args = [int(a) + 1 for a in args]
+                    targets = [int(t) + 1 for t in targets.split(',')]
                     questions.append((question_type, args, targets))
             dataset.append((edges, questions))
@@ -44,15 +45,17 @@
                     else:
                         edge_type = edge_type_dict[etype]
 
-                    edges.append((int(src)+1, edge_type, int(tgt)+1))
+                    edges.append((int(src) + 1, edge_type, int(tgt) + 1))
 
     return dataset, edge_type_dict, question_type_dict
 
+
 def write_dict(d, output_file):
     with open(output_file, 'w') as f:
-        for k,v in d.iteritems():
+        for k, v in d.iteritems():
             f.write('%s=%s\n' % (str(k), str(v)))
 
+
 def write_examples(dataset, output_file):
     with open(output_file, 'w') as f:
         for e, q in dataset:
@@ -66,19 +69,21 @@
                     f.write(' %d' % a)
                 f.write(' ')
                 for t in q_args[2]:
-                    f.write(' %d'% t)
+                    f.write(' %d' % t)
                 f.write('\n')
             f.write('\n')
 
+
 def write_dataset(dataset, edge_type_dict, question_type_dict, output_dir, output_prefix):
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
 
-    write_dict(edge_type_dict, os.path.join(output_dir, '%s_%s.txt' % (output_prefix, 'edge_types')))
+    write_dict(edge_type_dict, os.path.join(output_dir, '%s_%s.txt' % (output_prefix, 'edge_types')))
     write_dict(question_type_dict, os.path.join(output_dir, '%s_%s.txt' % (output_prefix, 'question_types')))
     write_examples(dataset, os.path.join(output_dir, '%s_%s.txt' % (output_prefix, 'graphs')))
 
+
 if __name__ == '__main__':
     if len(sys.argv) < 3:
         print('Usage: python %s <question_id> <input_dir> [<output_dir>]' % (os.path.basename(__file__)))
@@ -95,4 +100,3 @@ def write_dataset(dataset, edge_type_dict, question_type_dict, output_dir, outpu
 
         d, e, q = parse_dataset('%s/test/%s.txt' % (input_dir, question_id), e, q)
         write_dataset(d, e, q, '%s/test' % output_dir, question_id)
-
diff --git a/babi_data/extra_seq_tasks/rnn_preprocess.py 
b/babi_data/extra_seq_tasks/rnn_preprocess.py index 4ca9619bd..85e25e824 100755 --- a/babi_data/extra_seq_tasks/rnn_preprocess.py +++ b/babi_data/extra_seq_tasks/rnn_preprocess.py @@ -7,6 +7,7 @@ import numpy as np import argparse + def convert_graph_data(infile, outfile, n_val=0, n_train=0): data_list = [] with open(infile, 'r') as f: @@ -41,6 +42,7 @@ def convert_graph_data(infile, outfile, n_val=0, n_train=0): write_data_list_to_file(data_list[:n_train], outfile) write_data_list_to_file(data_list[-n_val:], outfile + '.val') + def write_data_list_to_file(data_list, filename): with open(filename, 'w') as f: for edges, questions in data_list: @@ -62,6 +64,7 @@ def write_data_list_to_file(data_list, filename): f.write(s_edges + s_q + '\n') + def convert_rnn_data(infile, outfile, dictfile=None): """ Convert each token in the example into an index to make processing easier. @@ -107,6 +110,7 @@ def convert_rnn_data(infile, outfile, dictfile=None): for k, v in sorted(d.items(), key=lambda t: t[0]): f.write('%s %d\n' % (k, v)) + if __name__ == '__main__': cmd_parser = argparse.ArgumentParser(description='Convert graph data into standard form for RNNs.') cmd_parser.add_argument('infile', help='path to the input file that contains all the graphs') @@ -121,4 +125,3 @@ def convert_rnn_data(infile, outfile, dictfile=None): convert_graph_data(args.infile, args.outfile, args.nval, args.ntrain) elif args.mode == 'rnn': convert_rnn_data(args.infile, args.outfile, args.dict) - diff --git a/babi_data/fix_q18.py b/babi_data/fix_q18.py index 6d195e260..fac3c5259 100755 --- a/babi_data/fix_q18.py +++ b/babi_data/fix_q18.py @@ -17,12 +17,13 @@ import random import re + def fix_file(file_path): with open(file_path) as f: lines = f.readlines() switch_prob = 0.5 - random.seed(1023) # make sure each time we run this we get the same result + random.seed(1023) # make sure each time we run this we get the same result p_eval = re.compile(r'(\d+) eval (\w+) ([><]) (\w+)(\s+)(\w+)([^\n]+)') @@ -31,7 +32,7 @@ def fix_file(file_path): m = p_eval.search(line) if m is not None: line_num, A, op, B, space, ans, others = m.groups() - if op == '>': # change all ">" to "<" + if op == '>': # change all ">" to "<" B, A = A, B op = '<' @@ -46,10 +47,10 @@ def fix_file(file_path): f.write(line) + if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('file_path', help='Path to the 18.txt file to be fixed.') opt = parser.parse_args() fix_file(opt.file_path) - diff --git a/babi_data/get_10_fold_data.sh b/babi_data/get_10_fold_data.sh index 77a260b1e..93d3fac01 100755 --- a/babi_data/get_10_fold_data.sh +++ b/babi_data/get_10_fold_data.sh @@ -11,51 +11,49 @@ mv lua/babi . cd .. for fold in {1..10}; do - echo ================ Generating fold $fold ===================== - echo + echo ================ Generating fold $fold ===================== + echo - cd bAbI-tasks + cd bAbI-tasks - mkdir symbolic_$fold - mkdir symbolic_$fold/train - mkdir symbolic_$fold/test + mkdir symbolic_$fold + mkdir symbolic_$fold/train + mkdir symbolic_$fold/test - echo - echo Generating 1000 training and 1000 test examples for each bAbI task. - echo This will take a while... - echo + echo + echo Generating 1000 training and 1000 test examples for each bAbI task. + echo This will take a while... 
+ echo - # for i in `seq 1 20`; do - for i in {4,15,16,18,19}; do - ./babi-tasks $i 1000 --symbolic true > symbolic_$fold/train/${i}.txt - ./babi-tasks $i 1000 --symbolic true > symbolic_$fold/test/${i}.txt - done + # for i in `seq 1 20`; do + for i in {4,15,16,18,19}; do + ./babi-tasks $i 1000 --symbolic true >symbolic_$fold/train/${i}.txt + ./babi-tasks $i 1000 --symbolic true >symbolic_$fold/test/${i}.txt + done - # fix q18 data - python ../fix_q18.py symbolic_$fold/train/18.txt - python ../fix_q18.py symbolic_$fold/test/18.txt + # fix q18 data + python ../fix_q18.py symbolic_$fold/train/18.txt + python ../fix_q18.py symbolic_$fold/test/18.txt - # back down - cd .. + # back down + cd .. - for i in {4,15,16,18,19}; do - python symbolic_preprocess.py $i bAbI-tasks/symbolic_$fold processed_$fold - done + for i in {4,15,16,18,19}; do + python symbolic_preprocess.py $i bAbI-tasks/symbolic_$fold processed_$fold + done - # RNN data + # RNN data - mkdir processed_$fold/rnn - mkdir processed_$fold/rnn/train - mkdir processed_$fold/rnn/test + mkdir processed_$fold/rnn + mkdir processed_$fold/rnn/train + mkdir processed_$fold/rnn/test - for i in {4,15,16,18,19}; do - python rnn_preprocess.py processed_$fold/train/${i}_graphs.txt processed_$fold/rnn/train/${i}_tokens.txt --mode graph --nval 50 - python rnn_preprocess.py processed_$fold/rnn/train/${i}_tokens.txt processed_$fold/rnn/train/${i}_rnn.txt --mode rnn - python rnn_preprocess.py processed_$fold/rnn/train/${i}_tokens.txt.val processed_$fold/rnn/train/${i}_rnn.txt.val --mode rnn --dict processed_$fold/rnn/train/${i}_rnn.txt.dict + for i in {4,15,16,18,19}; do + python rnn_preprocess.py processed_$fold/train/${i}_graphs.txt processed_$fold/rnn/train/${i}_tokens.txt --mode graph --nval 50 + python rnn_preprocess.py processed_$fold/rnn/train/${i}_tokens.txt processed_$fold/rnn/train/${i}_rnn.txt --mode rnn + python rnn_preprocess.py processed_$fold/rnn/train/${i}_tokens.txt.val processed_$fold/rnn/train/${i}_rnn.txt.val --mode rnn --dict processed_$fold/rnn/train/${i}_rnn.txt.dict - python rnn_preprocess.py processed_$fold/test/${i}_graphs.txt processed_$fold/rnn/test/${i}_tokens.txt --mode graph - python rnn_preprocess.py processed_$fold/rnn/test/${i}_tokens.txt processed_$fold/rnn/test/${i}_rnn.txt --dict processed_$fold/rnn/train/${i}_rnn.txt.dict --mode rnn - done + python rnn_preprocess.py processed_$fold/test/${i}_graphs.txt processed_$fold/rnn/test/${i}_tokens.txt --mode graph + python rnn_preprocess.py processed_$fold/rnn/test/${i}_tokens.txt processed_$fold/rnn/test/${i}_rnn.txt --dict processed_$fold/rnn/train/${i}_rnn.txt.dict --mode rnn + done done - - diff --git a/babi_data/rnn_preprocess.py b/babi_data/rnn_preprocess.py index 628398221..3853b2ecb 100755 --- a/babi_data/rnn_preprocess.py +++ b/babi_data/rnn_preprocess.py @@ -7,6 +7,7 @@ import numpy as np import argparse + def convert_graph_data(infile, outfile, n_val=0, n_train=0): data_list = [] with open(infile, 'r') as f: @@ -41,6 +42,7 @@ def convert_graph_data(infile, outfile, n_val=0, n_train=0): write_data_list_to_file(data_list[:n_train], outfile) write_data_list_to_file(data_list[-n_val:], outfile + '.val') + def write_data_list_to_file(data_list, filename): with open(filename, 'w') as f: for edges, questions in data_list: @@ -62,6 +64,7 @@ def write_data_list_to_file(data_list, filename): f.write(s_edges + s_q + '\n') + def convert_rnn_data(infile, outfile, dictfile=None): """ Convert each token in the example into an index to make processing easier. 
@@ -107,6 +110,7 @@ def convert_rnn_data(infile, outfile, dictfile=None): for k, v in sorted(d.items(), key=lambda t: t[0]): f.write('%s %d\n' % (k, v)) + if __name__ == '__main__': cmd_parser = argparse.ArgumentParser(description='Convert graph data into standard form for RNNs.') cmd_parser.add_argument('infile', help='path to the input file that contains all the graphs') @@ -121,4 +125,3 @@ def convert_rnn_data(infile, outfile, dictfile=None): convert_graph_data(args.infile, args.outfile, args.nval, args.ntrain) elif args.mode == 'rnn': convert_rnn_data(args.infile, args.outfile, args.dict) - diff --git a/babi_data/symbolic_preprocess.py b/babi_data/symbolic_preprocess.py index 5211cd2d7..969067e70 100755 --- a/babi_data/symbolic_preprocess.py +++ b/babi_data/symbolic_preprocess.py @@ -8,6 +8,7 @@ import sys import argparse + def parse_dataset(file_name, edge_type_dict={}, node_id_dict={}, question_type_dict={}, label_dict={}): """ Parse the dataset file. @@ -86,7 +87,7 @@ def parse_dataset(file_name, edge_type_dict={}, node_id_dict={}, question_type_d qtype = tokens[3] tgt = tokens[4] label_str = tokens[5] if not tokens[5].isdigit() else None - + if label_str is not None: if label_str not in label_dict: label = len(label_dict) + 1 @@ -98,7 +99,7 @@ def parse_dataset(file_name, edge_type_dict={}, node_id_dict={}, question_type_d src_id = node_id_dict[src] tgt_id = node_id_dict[tgt] - + if qtype not in question_type_dict: question_type = len(question_type_dict) + 1 question_type_dict[qtype] = question_type @@ -108,7 +109,8 @@ def parse_dataset(file_name, edge_type_dict={}, node_id_dict={}, question_type_d if tokens[2] == 'path': questions.append([question_type, src_id, tgt_id] + labels) else: - questions.append((question_type, src_id, tgt_id, label) if label is not None else (question_type, src_id, tgt_id)) + questions.append((question_type, src_id, tgt_id, label) if label is not None else ( + question_type, src_id, tgt_id)) prev_id = line_id @@ -117,17 +119,19 @@ def parse_dataset(file_name, edge_type_dict={}, node_id_dict={}, question_type_d return dataset, edge_type_dict, node_id_dict, question_type_dict, label_dict + def write_dataset(dataset, edge_type_dict, node_id_dict, question_type_dict, label_dict, output_dir, output_prefix): if not os.path.exists(output_dir): os.makedirs(output_dir) - write_dict(edge_type_dict, os.path.join(output_dir, '%s_%s.txt' % (output_prefix, 'edge_types'))) - write_dict(node_id_dict, os.path.join(output_dir, '%s_%s.txt' % (output_prefix, 'node_ids'))) + write_dict(edge_type_dict, os.path.join(output_dir, '%s_%s.txt' % (output_prefix, 'edge_types'))) + write_dict(node_id_dict, os.path.join(output_dir, '%s_%s.txt' % (output_prefix, 'node_ids'))) write_dict(question_type_dict, os.path.join(output_dir, '%s_%s.txt' % (output_prefix, 'question_types'))) - write_dict(label_dict, os.path.join(output_dir, '%s_%s.txt' % (output_prefix, 'labels'))) + write_dict(label_dict, os.path.join(output_dir, '%s_%s.txt' % (output_prefix, 'labels'))) write_examples(dataset, os.path.join(output_dir, '%s_%s.txt' % (output_prefix, 'graphs'))) + def write_examples(dataset, output_file): with open(output_file, 'w') as f: for e, q in dataset: @@ -140,16 +144,20 @@ def write_examples(dataset, output_file): f.write('\n') f.write('\n') + def write_dict(d, output_file): with open(output_file, 'w') as f: - for k,v in d.iteritems(): + for k, v in d.iteritems(): f.write('%s=%s\n' % (str(k), str(v))) + if __name__ == '__main__': parser = argparse.ArgumentParser() 
parser.add_argument('question_id', type=int, help='ID of the question to process. We only use {4,15,16,18,19}') - parser.add_argument('input_dir', default='symbolic', help='Path to the directory that contains generated raw symbolic data, should contain two directories train and test.') - parser.add_argument('output_dir', default='processed', help='Path to the directory to store processed symbolic data.') + parser.add_argument('input_dir', default='symbolic', + help='Path to the directory that contains generated raw symbolic data, should contain two directories train and test.') + parser.add_argument('output_dir', default='processed', + help='Path to the directory to store processed symbolic data.') opt = parser.parse_args() @@ -162,4 +170,3 @@ def write_dict(d, output_file): d, e, n, q, l = parse_dataset(os.path.join(input_dir, 'test', '%s.txt' % question_id), e, n, q, l) write_dataset(d, e, n, q, l, os.path.join(output_dir, 'test'), question_id) - diff --git a/main.py b/main.py index 4395232c3..a81d0b11b 100644 --- a/main.py +++ b/main.py @@ -38,6 +38,7 @@ if opt.cuda: torch.cuda.manual_seed_all(opt.manualSeed) + def main(opt): train_dataset = bAbIDataset(opt.dataroot, opt.question_id, True) train_dataloader = bAbIDataloader(train_dataset, batch_size=opt.batchSize, \ @@ -70,4 +71,3 @@ def main(opt): if __name__ == "__main__": main(opt) - diff --git a/main_eager.py b/main_eager.py new file mode 100644 index 000000000..ad20d9954 --- /dev/null +++ b/main_eager.py @@ -0,0 +1,60 @@ +import argparse +import random + +import tensorflow as tf + +from model_eager import GGNN +from utils.train_eager import train +from utils.test_eager import test +from utils.data.dataset import bAbIDataset + +# Set up Eager Execution +config = tf.compat.v1.ConfigProto(allow_soft_placement=True, + intra_op_parallelism_threads=1, + inter_op_parallelism_threads=1) +config.gpu_options.allow_growth = True +tf.compat.v1.enable_eager_execution(config=config) +tf.compat.v1.enable_resource_variables() + +parser = argparse.ArgumentParser() +parser.add_argument('--task_id', type=int, default=4, help='bAbI task id') +parser.add_argument('--question_id', type=int, default=0, help='question types') +parser.add_argument('--workers', type=int, help='number of data loading workers', default=2) +parser.add_argument('--batchSize', type=int, default=10, help='input batch size') +parser.add_argument('--state_dim', type=int, default=4, help='GGNN hidden state size') +parser.add_argument('--n_steps', type=int, default=5, help='propogation steps number of GGNN') +parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for') +parser.add_argument('--lr', type=float, default=0.01, help='learning rate') +parser.add_argument('--cuda', action='store_true', help='enables cuda') +parser.add_argument('--verbal', action='store_true', help='print training info or not') +parser.add_argument('--manualSeed', type=int, help='manual seed') + +opt = parser.parse_args() +print(opt) + +if opt.manualSeed is None: + opt.manualSeed = random.randint(1, 10000) +print("Random Seed: ", opt.manualSeed) +random.seed(opt.manualSeed) +tf.compat.v1.set_random_seed(opt.manualSeed) + +opt.dataroot = 'babi_data/processed_1/train/%d_graphs.txt' % opt.task_id + + +def main(opt): + train_dataset = bAbIDataset(opt.dataroot, opt.question_id, True) + test_dataset = bAbIDataset(opt.dataroot, opt.question_id, False) + opt.annotation_dim = 1 # for bAbI + opt.n_edge_types = train_dataset.n_edge_types + opt.n_node = train_dataset.n_node + net = GGNN(opt) 
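+    # softmax cross-entropy over the per-node output scores; note that AdamOptimizer()
+    # is constructed with its default learning rate, so the --lr flag is parsed but not used here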
+ criterion = tf.compat.v1.losses.softmax_cross_entropy + optimiser = tf.compat.v1.train.AdamOptimizer() + + for epoch in range(0, opt.niter): + train(epoch, train_dataset, net, criterion, optimiser, opt) + test(test_dataset, net, criterion, opt) + + +if __name__ == "__main__": + main(opt) diff --git a/model.py b/model.py index 1ad8790a2..b167984c8 100644 --- a/model.py +++ b/model.py @@ -1,12 +1,14 @@ import torch import torch.nn as nn + class AttrProxy(object): """ Translates index lookups into attribute lookups. To implement some trick which able to use list of nn.Module in a nn.Module see https://discuss.pytorch.org/t/list-of-nn-module-in-a-nn-module/219/2 """ + def __init__(self, module, prefix): self.module = module self.prefix = prefix @@ -20,6 +22,7 @@ class Propogator(nn.Module): Gated Propogator for GGNN Using LSTM gating mechanism """ + def __init__(self, state_dim, n_node, n_edge_types): super(Propogator, self).__init__() @@ -27,21 +30,21 @@ def __init__(self, state_dim, n_node, n_edge_types): self.n_edge_types = n_edge_types self.reset_gate = nn.Sequential( - nn.Linear(state_dim*3, state_dim), + nn.Linear(state_dim * 3, state_dim), nn.Sigmoid() ) self.update_gate = nn.Sequential( - nn.Linear(state_dim*3, state_dim), + nn.Linear(state_dim * 3, state_dim), nn.Sigmoid() ) self.tansform = nn.Sequential( - nn.Linear(state_dim*3, state_dim), + nn.Linear(state_dim * 3, state_dim), nn.Tanh() ) def forward(self, state_in, state_out, state_cur, A): - A_in = A[:, :, :self.n_node*self.n_edge_types] - A_out = A[:, :, self.n_node*self.n_edge_types:] + A_in = A[:, :, :self.n_node * self.n_edge_types] + A_out = A[:, :, self.n_node * self.n_edge_types:] a_in = torch.bmm(A_in, state_in) a_out = torch.bmm(A_out, state_out) @@ -63,10 +66,11 @@ class GGNN(nn.Module): Mode: SelectNode Implementation based on https://arxiv.org/abs/1511.05493 """ + def __init__(self, opt): super(GGNN, self).__init__() - assert (opt.state_dim >= opt.annotation_dim, \ + assert (opt.state_dim >= opt.annotation_dim, \ 'state_dim must be no less than annotation_dim') self.state_dim = opt.state_dim @@ -111,9 +115,9 @@ def forward(self, prop_state, annotation, A): in_states.append(self.in_fcs[i](prop_state)) out_states.append(self.out_fcs[i](prop_state)) in_states = torch.stack(in_states).transpose(0, 1).contiguous() - in_states = in_states.view(-1, self.n_node*self.n_edge_types, self.state_dim) + in_states = in_states.view(-1, self.n_node * self.n_edge_types, self.state_dim) out_states = torch.stack(out_states).transpose(0, 1).contiguous() - out_states = out_states.view(-1, self.n_node*self.n_edge_types, self.state_dim) + out_states = out_states.view(-1, self.n_node * self.n_edge_types, self.state_dim) prop_state = self.propogator(in_states, out_states, prop_state, A) @@ -121,5 +125,4 @@ def forward(self, prop_state, annotation, A): output = self.out(join_state) output = output.sum(2) - return output diff --git a/model_eager.py b/model_eager.py new file mode 100644 index 000000000..6ba1b0bfc --- /dev/null +++ b/model_eager.py @@ -0,0 +1,159 @@ +import tensorflow as tf + +class AttrProxy(object): + """ + Translates index lookups into attribute lookups. 
+ To implement some trick which able to use list of nn.Module in a nn.Module + see https://discuss.pytorch.org/t/list-of-nn-module-in-a-nn-module/219/2 + """ + + def __init__(self, module, prefix): + self.module = module + self.prefix = prefix + + def __getitem__(self, i): + return getattr(self.module, self.prefix + str(i)) + + +class Propagator(tf.keras.Model): + """ + Gated Propogator for GGNN + Using LSTM gating mechanism + """ + + def __init__(self, state_dim, n_node, n_edge_types): + super(Propagator, self).__init__() + + self.n_node = n_node + self.n_edge_types = n_edge_types + self.reset_gate = tf.keras.layers.Dense(state_dim, activation="sigmoid") + # self.reset_gate = nn.Sequential( + # nn.Linear(state_dim * 3, state_dim), + # nn.Sigmoid() + # ) + + self.update_gate = tf.keras.layers.Dense(state_dim, activation="sigmoid") + # self.update_gate = nn.Sequential( + # nn.Linear(state_dim * 3, state_dim), + # nn.Sigmoid() + # ) + + self.transform = tf.keras.layers.Dense(state_dim) + + # self.tansform = nn.Sequential( + # nn.Linear(state_dim * 3, state_dim), + # nn.Tanh() + # ) + + @tf.contrib.eager.defun(autograph=False) + def call(self, state_in, state_out, state_cur, A): + A_in = A[:, :, :self.n_node * self.n_edge_types] + A_out = A[:, :, self.n_node * self.n_edge_types:] + + a_in = tf.compat.v1.linalg.matmul(A_in, state_in) + # a_in = torch.bmm(A_in, state_in) + + a_out = tf.compat.v1.linalg.matmul(A_out, state_out) + # a_out = torch.bmm(A_out, state_out) + + a = tf.compat.v1.concat([a_in, a_out, state_cur], axis=2) + # a = torch.cat((a_in, a_out, state_cur), 2) + + r = self.reset_gate(a) + z = self.update_gate(a) + + joined_input = tf.compat.v1.concat([a_in, a_out, r * state_cur], axis=2) + # joined_input = torch.cat((a_in, a_out, r * state_cur), 2) + + h_hat = self.transform(joined_input) + + output = (1 - z) * state_cur + z * h_hat + + return output + + +class GGNN(tf.keras.Model): + """ + Gated Graph Sequence Neural Networks (GGNN) + Mode: SelectNode + Implementation based on https://arxiv.org/abs/1511.05493 + """ + + def __init__(self, opt): + super(GGNN, self).__init__() + + assert (opt.state_dim >= opt.annotation_dim, 'state_dim must be no less than annotation_dim') + + self.state_dim = opt.state_dim + self.annotation_dim = opt.annotation_dim + self.n_edge_types = opt.n_edge_types + self.n_node = opt.n_node + self.n_steps = opt.n_steps + + for i in range(self.n_edge_types): + # incoming and outgoing edge embedding + + in_fc = tf.keras.layers.Dense(self.state_dim) + # in_fc = nn.Linear(self.state_dim, self.state_dim) + + out_fc = tf.keras.layers.Dense(self.state_dim) + # out_fc = nn.Linear(self.state_dim, self.state_dim) + + setattr(self, "in_{}".format(i), in_fc) + setattr(self, "out_{}".format(i), out_fc) + # self.add_module("in_{}".format(i), in_fc) + # self.add_module("out_{}".format(i), out_fc) + + self.in_fcs = AttrProxy(self, "in_") + self.out_fcs = AttrProxy(self, "out_") + + # Propagation Model + self.propagator = Propagator(self.state_dim, self.n_node, self.n_edge_types) + + # Output Model + self.out = tf.keras.layers.Dense(1, activation="tanh") + # self.out = nn.Sequential( + # nn.Linear(self.state_dim + self.annotation_dim, self.state_dim), + # nn.Tanh(), + # nn.Linear(self.state_dim, 1) + # ) + + # self._initialization() + # + # def _initialization(self): + # for m in self.modules(): + # if isinstance(m, nn.Linear): + # m.weight.data.normal_(0.0, 0.02) + # m.bias.data.fill_(0) + + @tf.contrib.eager.defun(autograph=False) + def call(self, prop_state, annotation, A): 
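+        # prop_state: [batch, n_node, state_dim] node states (annotations padded with zeros),
+        # annotation: [batch, n_node, annotation_dim],
+        # A: [batch, n_node, n_node * n_edge_types * 2] adjacency matrix, incoming-edge blocks followed by outgoing-edge blocks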
+ for i_step in range(self.n_steps): + in_states = [] + out_states = [] + for i in range(self.n_edge_types): + in_states.append(self.in_fcs[i](prop_state)) + out_states.append(self.out_fcs[i](prop_state)) + + in_states = tf.transpose(tf.compat.v1.stack(in_states), perm=(1, 0, 2, 3)) + # in_states = torch.stack(in_states).transpose(0, 1).contiguous() + + in_states = tf.reshape(in_states, [-1, self.n_node * self.n_edge_types, self.state_dim]) + # in_states = in_states.view(-1, self.n_node * self.n_edge_types, self.state_dim) + + out_states = tf.transpose(tf.compat.v1.stack(out_states), perm=(1, 0, 2, 3)) + # out_states = torch.stack(out_states).transpose(0, 1).contiguous() + + out_states = tf.reshape(out_states, [-1, self.n_node * self.n_edge_types, self.state_dim]) + # out_states = out_states.view(-1, self.n_node * self.n_edge_types, self.state_dim) + + prop_state = self.propagator(in_states, out_states, prop_state, A) + + join_state = tf.compat.v1.concat([prop_state, annotation], axis=2) + # join_state = torch.cat((prop_state, annotation), 2) + + output = self.out(join_state) + + output = tf.compat.v1.math.reduce_sum(output, axis=2) + # output = output.sum(2) + return output diff --git a/utils/__pycache__/__init__.cpython-36.pyc b/utils/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 000000000..ff086a1fa Binary files /dev/null and b/utils/__pycache__/__init__.cpython-36.pyc differ diff --git a/utils/__pycache__/test_eager.cpython-36.pyc b/utils/__pycache__/test_eager.cpython-36.pyc new file mode 100644 index 000000000..d5eb6e6f9 Binary files /dev/null and b/utils/__pycache__/test_eager.cpython-36.pyc differ diff --git a/utils/__pycache__/train_eager.cpython-36.pyc b/utils/__pycache__/train_eager.cpython-36.pyc new file mode 100644 index 000000000..6dd91db39 Binary files /dev/null and b/utils/__pycache__/train_eager.cpython-36.pyc differ diff --git a/utils/data/__pycache__/__init__.cpython-36.pyc b/utils/data/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 000000000..a2268137c Binary files /dev/null and b/utils/data/__pycache__/__init__.cpython-36.pyc differ diff --git a/utils/data/__pycache__/dataset.cpython-36.pyc b/utils/data/__pycache__/dataset.cpython-36.pyc new file mode 100644 index 000000000..c1c3897ca Binary files /dev/null and b/utils/data/__pycache__/dataset.cpython-36.pyc differ diff --git a/utils/data/dataloader.py b/utils/data/dataloader.py index 4846a864a..db153b290 100644 --- a/utils/data/dataloader.py +++ b/utils/data/dataloader.py @@ -1,5 +1,6 @@ from torch.utils.data import DataLoader + class bAbIDataloader(DataLoader): def __init__(self, *args, **kwargs): diff --git a/utils/data/dataset.py b/utils/data/dataset.py index 7f69d2e6b..ad2aabef1 100644 --- a/utils/data/dataset.py +++ b/utils/data/dataset.py @@ -1,13 +1,14 @@ import numpy as np + def load_graphs_from_file(file_name): data_list = [] edge_list = [] target_list = [] - with open(file_name,'r') as f: + with open(file_name, 'r') as f: for line in f: if len(line.strip()) == 0: - data_list.append([edge_list,target_list]) + data_list.append([edge_list, target_list]) edge_list = [] target_list = [] else: @@ -23,6 +24,7 @@ def load_graphs_from_file(file_name): edge_list.append(digits) return data_list + def find_max_edge_id(data_list): max_edge_id = 0 for data in data_list: @@ -32,6 +34,7 @@ def find_max_edge_id(data_list): max_edge_id = item[1] return max_edge_id + def find_max_node_id(data_list): max_node_id = 0 for data in data_list: @@ -43,6 +46,7 @@ def find_max_node_id(data_list): 
max_node_id = item[2]
     return max_node_id
 
+
 def find_max_task_id(data_list):
     max_node_id = 0
     for data in data_list:
@@ -52,12 +56,14 @@
                 max_node_id = item[0]
     return max_node_id
 
+
 def split_set(data_list):
     n_examples = len(data_list)
     idx = range(n_examples)
     train = idx[:50]
     val = idx[-50:]
-    return np.array(data_list)[train],np.array(data_list)[val]
+    return np.array(data_list)[train], np.array(data_list)[val]
+
 
 def data_convert(data_list, n_annotation_dim):
     n_nodes = find_max_node_id(data_list)
@@ -72,18 +78,19 @@
             task_type = target[0]
             task_output = target[-1]
             annotation = np.zeros([n_nodes, n_annotation_dim])
-            annotation[target[1]-1][0] = 1
-            task_data_list[task_type-1].append([edge_list, annotation, task_output])
+            annotation[target[1] - 1][0] = 1
+            task_data_list[task_type - 1].append([edge_list, annotation, task_output])
     return task_data_list
 
+
 def create_adjacency_matrix(edges, n_nodes, n_edge_types):
     a = np.zeros([n_nodes, n_nodes * n_edge_types * 2])
     for edge in edges:
         src_idx = edge[0]
         e_type = edge[1]
         tgt_idx = edge[2]
-        a[tgt_idx-1][(e_type - 1) * n_nodes + src_idx - 1] = 1
-        a[src_idx-1][(e_type - 1 + n_edge_types) * n_nodes + tgt_idx - 1] = 1
+        a[tgt_idx - 1][(e_type - 1) * n_nodes + src_idx - 1] = 1
+        a[src_idx - 1][(e_type - 1 + n_edge_types) * n_nodes + tgt_idx - 1] = 1
     return a
 
 
@@ -91,9 +98,10 @@
 class bAbIDataset():
     """
     Load bAbI tasks for GGNN
     """
+
     def __init__(self, path, task_id, is_train):
         all_data = load_graphs_from_file(path)
-        self.n_edge_types =  find_max_edge_id(all_data)
+        self.n_edge_types = find_max_edge_id(all_data)
         self.n_tasks = find_max_task_id(all_data)
         self.n_node = find_max_node_id(all_data)
 
@@ -114,4 +122,3 @@ def __getitem__(self, index):
 
     def __len__(self):
         return len(self.data)
-
diff --git a/utils/test.py b/utils/test.py
index 4ac55adf4..c735af3b3 100644
--- a/utils/test.py
+++ b/utils/test.py
@@ -21,7 +21,7 @@ def test(dataloader, net, criterion, optimizer, opt):
 
         output = net(init_input, annotation, adj_matrix)
 
-        test_loss += criterion(output, target).data[0]
+        test_loss += criterion(output, target).data.item()
 
         pred = output.data.max(1, keepdim=True)[1]
         correct += pred.eq(target.data.view_as(pred)).cpu().sum()
diff --git a/utils/test_eager.py b/utils/test_eager.py
new file mode 100644
index 000000000..512b59f70
--- /dev/null
+++ b/utils/test_eager.py
@@ -0,0 +1,31 @@
+import tensorflow as tf
+import numpy as np
+
+
+def test(dataset, net, criterion, opt):
+    test_loss = 0
+    correct = 0
+    for i in range(int(len(dataset) / opt.batchSize)):
+        adj_matrix, annotation, target = list(), list(), list()
+        for index in range(opt.batchSize * i, opt.batchSize * (i + 1), 1):
+            temp = dataset[index]
+            adj_matrix.append(temp[0])
+            annotation.append(temp[1])
+            target.append(temp[2])
+        adj_matrix, annotation, target = np.array(adj_matrix).astype(np.float32), np.array(annotation).astype(
+            np.float32), np.array(target).astype(np.int64)
+
+        padding = tf.compat.v1.zeros([len(annotation), opt.n_node, opt.state_dim - opt.annotation_dim])
+        # padding = torch.zeros(len(annotation), opt.n_node, opt.state_dim - opt.annotation_dim).double()
+
+        init_input = tf.concat([annotation, padding], axis=2)
+        # init_input = torch.cat((annotation, padding), 2)
+
+        logits = net(init_input, annotation, adj_matrix)
+        test_loss += float(criterion(onehot_labels=tf.one_hot(target, logits.shape[1]), logits=logits))
+        pred = tf.math.argmax(logits, axis=-1)
+        correct += sum([int(a) == int(b) for a, b in zip(pred.numpy(), target)])
+
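+    # report the summed batch losses averaged over the dataset size, plus overall accuracy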
+    test_loss /= len(dataset)
+    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(test_loss, correct, len(dataset),
+                                                                              100. * correct / len(dataset)))
diff --git a/utils/train.py b/utils/train.py
index deb0b3658..941a4375d 100644
--- a/utils/train.py
+++ b/utils/train.py
@@ -1,6 +1,7 @@
 import torch
 from torch.autograd import Variable
 
+
 def train(epoch, dataloader, net, criterion, optimizer, opt):
     net.train()
     for i, (adj_matrix, annotation, target) in enumerate(dataloader, 0):
diff --git a/utils/train_eager.py b/utils/train_eager.py
new file mode 100644
index 000000000..b68f61054
--- /dev/null
+++ b/utils/train_eager.py
@@ -0,0 +1,57 @@
+import tensorflow as tf
+import numpy as np
+import functools
+
+
+def train(epoch, dataset, net, criterion, optimizer, opt):
+    for i in range(int(len(dataset) / opt.batchSize)):
+        adj_matrix, annotation, target = list(), list(), list()
+        for index in range(opt.batchSize * i, opt.batchSize * (i + 1), 1):
+            temp = dataset[index]
+            adj_matrix.append(temp[0])
+            annotation.append(temp[1])
+            target.append(temp[2])
+        adj_matrix, annotation, target = np.array(adj_matrix).astype(np.float32), np.array(annotation).astype(
+            np.float32), np.array(target).astype(np.int64)
+
+        padding = tf.compat.v1.zeros([len(annotation), opt.n_node, opt.state_dim - opt.annotation_dim])
+        # padding = torch.zeros(len(annotation), opt.n_node, opt.state_dim - opt.annotation_dim).double()
+
+        init_input = tf.concat([annotation, padding], axis=2)
+        # init_input = torch.cat((annotation, padding), 2)
+
+        """
+        There are mainly two ways to do backprop in Eager mode:
+
+        - this one
+
+        ```python
+        def model_loss(criterion, output, target):
+            loss = criterion(output, target)
+            return loss
+
+        output = net(init_input, annotation, adj_matrix)
+        loss_func = functools.partial(model_loss, criterion, output, target)
+        optimizer.minimize(loss_func)
+        ```
+
+        - or this one
+
+        ```python
+        with tf.GradientTape() as tape:
+            output = net(init_input, annotation, adj_matrix)
+            loss = criterion(output, target)
+
+        grads = tape.gradient(loss, net.trainable_weights)
+        optimizer.apply_gradients(zip(grads, net.trainable_weights))
+        ```
+        """
+
+        with tf.GradientTape() as tape:
+            logits = net(init_input, annotation, adj_matrix)
+            loss = criterion(onehot_labels=tf.one_hot(target, logits.shape[1]), logits=logits)
+
+        grads = tape.gradient(loss, net.trainable_weights)
+        optimizer.apply_gradients(zip(grads, net.trainable_weights))
+
+        if i % int(len(dataset) / 10 + 1) == 0 and opt.verbal:
+            print('[%d/%d][%d/%d] Loss: %.4f' % (epoch, opt.niter, i, len(dataset), loss.numpy()))