From c357ac83891e134f544732c2875aaea781a3f777 Mon Sep 17 00:00:00 2001 From: Dafne van Kuppevelt Date: Wed, 23 Sep 2020 17:01:21 +0200 Subject: [PATCH 1/2] Added simple test --- requirements.txt | 2 ++ test/data/test.conllu | 14 ++++++++++++++ test/test_conllu.py | 19 +++++++++++++++++++ 3 files changed, 35 insertions(+) create mode 100644 test/data/test.conllu create mode 100644 test/test_conllu.py diff --git a/requirements.txt b/requirements.txt index 170f835..484f773 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,5 @@ seaborn stanza scorch Jinja2 +numba +tensorboard diff --git a/test/data/test.conllu b/test/data/test.conllu new file mode 100644 index 0000000..fb78ab3 --- /dev/null +++ b/test/data/test.conllu @@ -0,0 +1,14 @@ +# newdoc id = jip-en-janneke +# sent_id = 1 +# text = Jip liep in de tuin en hij verveelde zich zo . +1 Jip Jip PROPN N|eigen|ev|basis|zijd|stan Gender=Com|Number=Sing 2 nsubj _ _ _ _ (0) +2 liep lopen VERB WW|pv|verl|ev Number=Sing|Tense=Past|VerbForm=Fin 0 root _ _ _ _ - +3 in in ADP VZ|init _ 5 case _ _ _ _ - +4 de de DET LID|bep|stan|rest Definite=Def 5 det _ _ _ _ (1 +5 tuin tuin NOUN N|soort|ev|basis|zijd|stan Gender=Com|Number=Sing 2 obl _ _ _ _ 1) +6 en en CCONJ VG|neven _ 8 cc _ _ _ _ - +7 hij hij PRON VNW|pers|pron|nomin|vol|3|ev|masc Case=Nom|Person=3|PronType=Prs 8 nsubj _ _ _ _ (0) +8 verveelde vervelen VERB WW|pv|verl|ev Number=Sing|Tense=Past|VerbForm=Fin 2 conj _ _ _ _ - +9 zich zich PRON VNW|refl|pron|obl|red|3|getal Case=Acc|Person=3|PronType=Prs|Reflex=Yes 8 expl:pv _ _ _ _ (2) +10 zo zo ADV BW _ 8 advmod _ _ _ _ - +11 . . PUNCT LET _ 2 punct _ _ _ _ - diff --git a/test/test_conllu.py b/test/test_conllu.py new file mode 100644 index 0000000..cff86fe --- /dev/null +++ b/test/test_conllu.py @@ -0,0 +1,19 @@ +import stroll.conllu +import os + +__here__ = os.path.dirname(os.path.realpath(__file__)) + + +def test_empty_conll(): + dataset = stroll.conllu.ConlluDataset() + assert len(dataset.sentences) == 0 + + +def test_load_conllu(): + input_file = os.path.join(__here__, 'data', 'test.conllu') + dataset = stroll.conllu.ConlluDataset(input_file) + assert len(dataset.sentences) == 1 + sent = dataset[0] + assert len(sent) == 11 + tok = sent[0] + tok.COREF = '(0)' From cb13d2d2abae39ce024c5fd3ccaf449d322efd83 Mon Sep 17 00:00:00 2001 From: Dafne van Kuppevelt Date: Wed, 23 Sep 2020 19:41:47 +0200 Subject: [PATCH 2/2] Make special field for head based coref --- .gitignore | 2 ++ run_coref.py | 8 ++++---- run_entity.py | 2 +- run_mentions.py | 6 +++--- stroll/conllu.py | 14 +++++++++++--- stroll/coref.py | 22 +++++++++++----------- stroll/graph.py | 2 +- test/test_conllu.py | 3 ++- utils/coref_check.py | 8 ++++---- 9 files changed, 39 insertions(+), 28 deletions(-) diff --git a/.gitignore b/.gitignore index d5c8e40..184ac22 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ **/__pycache__ models/ +*.egg-info/ +.ipynb_checkpoints/ diff --git a/run_coref.py b/run_coref.py index abc4d2e..527efad 100644 --- a/run_coref.py +++ b/run_coref.py @@ -78,8 +78,8 @@ def write_html(dataset, name): 'sentences': [sentence] } for token in sentence: - if token.COREF != '_': - entities[token.COREF] = 1 + if token.COREF_HEAD != '_': + entities[token.COREF_HEAD] = 1 with open(name, 'w') as f: f.write(template.render( @@ -167,12 +167,12 @@ def main(args): # clear old info for s in dataset: for t in s: - t.COREF = '_' + t.COREF_HEAD = '_' sent_index = test_graph.ndata['sent_index'][mention_idxs] token_index = test_graph.ndata['token_index'][mention_idxs] for s, t, m in zip(sent_index, token_index, system_clusters): - dataset[s][t].COREF = '{:d}'.format(int(m)) + dataset[s][t].COREF_HEAD = '{:d}'.format(int(m)) if args.score: # score the clustering diff --git a/run_entity.py b/run_entity.py index f3b476d..be7e29b 100644 --- a/run_entity.py +++ b/run_entity.py @@ -196,7 +196,7 @@ def eval(net, doc): if args.verbose: print(mention) mention.refid = refid - mention.sentence[mention.head].COREF = '{}'.format(refid) + mention.sentence[mention.head].COREF_HEAD = '{}'.format(refid) if args.verbose: print('= - = - = - = - = - = - =') diff --git a/run_mentions.py b/run_mentions.py index 36b851b..57450e7 100644 --- a/run_mentions.py +++ b/run_mentions.py @@ -170,15 +170,15 @@ def write_output_conll2012(dataset, filename): zip(sent_indices, token_indices, system): if isMention: if token_index in refid_lookup[sent_index]: - dataset[sent_index][token_index].COREF = \ + dataset[sent_index][token_index].COREF_HEAD = \ refid_lookup[sent_index][token_index] else: # treat every mention as a new entity - dataset[sent_index][token_index].COREF = \ + dataset[sent_index][token_index].COREF_HEAD = \ '{}'.format(entity) entity += 1 else: - dataset[sent_index][token_index].COREF = '_' + dataset[sent_index][token_index].COREF_HEAD = '_' if args.score: # correct mentions: diff --git a/stroll/conllu.py b/stroll/conllu.py index 3d19449..60a38e7 100644 --- a/stroll/conllu.py +++ b/stroll/conllu.py @@ -18,7 +18,7 @@ class Token(): """A class representing a single token, ie. a word, with its annotation.""" - def __init__(self, fields, isEncoded=False): + def __init__(self, fields, isEncoded=False, is_preprocessed=False): self.isEncoded = isEncoded if len(fields) < 10: logging.warn( @@ -51,10 +51,18 @@ def __init__(self, fields, isEncoded=False): # Treat field 12 as co-reference info # NOTE: this a private extension the to conllu format + self.COREF = None + self.COREF_HEAD = None if len(fields) >= 13: - self.COREF = fields[12] + if is_preprocessed: + self.COREF_HEAD = fields[12] + else: + self.COREF = fields[12] else: - self.COREF = '_' + if is_preprocessed: + self.COREF_HEAD = '_' + else: + self.COREF = '_' # For coreference resolution if len(fields) >= 14: diff --git a/stroll/coref.py b/stroll/coref.py index ffffd52..9b8704a 100644 --- a/stroll/coref.py +++ b/stroll/coref.py @@ -144,7 +144,7 @@ def nested(self): while token.HEAD not in visited: token = sentence[token.HEAD] visited.append(token.ID) - if token.COREF != '_': + if token.COREF_HEAD != '_': return 1.0 return 0.0 @@ -355,7 +355,7 @@ def build_mentions_from_heads(sentence, heads): mentions[(head, head)] = Mention( head=head, sentence=sentence, - refid=sentence[head].COREF, + refid=sentence[head].COREF_HEAD, start=head, end=head, ids=[head], @@ -410,7 +410,7 @@ def build_mentions_from_heads(sentence, heads): mentions[(start, end)] = Mention( head=head, sentence=sentence, - refid=sentence[head].COREF, + refid=sentence[head].COREF_HEAD, start=sentence[id_start].ID, end=sentence[id_end].ID, ids=[sentence[i].ID for i in pruned_ids], @@ -429,7 +429,7 @@ def get_mentions(sentence): """ heads = [] for token in sentence: - if token.COREF != '_': + if token.COREF_HEAD != '_': heads.append(token.ID) return build_mentions_from_heads(sentence, heads) @@ -447,13 +447,13 @@ def mark_gold_anaphores(dataset): doc_rank = sentence.doc_rank entities = {} for token in sentence: - if token.COREF == '_': + if token.COREF_HEAD == '_': continue - if token.COREF in entities: - entities[token.COREF] += 1 + if token.COREF_HEAD in entities: + entities[token.COREF_HEAD] += 1 token.anaphore = 1.0 else: - entities[token.COREF] = 1 + entities[token.COREF_HEAD] = 1 token.anaphore = 0.0 @@ -665,13 +665,13 @@ def preprocess_sentence(sentence): bra_ket_mentions = get_mentions_from_bra_ket(sentence) head_mentions = convert_mentions(bra_ket_mentions) - # clear bra-ket annotations + # clear annotations if present for token in sentence: - token.COREF = '_' + token.COREF_HEAD = '_' # add head based annotations for mention in head_mentions: - sentence[mention.head].COREF = mention.refid + sentence[mention.head].COREF_HEAD = mention.refid return bra_ket_mentions, head_mentions diff --git a/stroll/graph.py b/stroll/graph.py index f253efc..10580db 100644 --- a/stroll/graph.py +++ b/stroll/graph.py @@ -68,7 +68,7 @@ def __getitem__(self, index): 0).view(1, -1), 'frame': token.FRAME, 'role': token.ROLE, - 'coref': token.COREF, + 'coref': token.COREF_HEAD, 'sent_index': torch.tensor([index], dtype=torch.int32), 'token_index': torch.tensor( [sentence.index(token.ID)], diff --git a/test/test_conllu.py b/test/test_conllu.py index cff86fe..3d17006 100644 --- a/test/test_conllu.py +++ b/test/test_conllu.py @@ -16,4 +16,5 @@ def test_load_conllu(): sent = dataset[0] assert len(sent) == 11 tok = sent[0] - tok.COREF = '(0)' + assert tok.COREF == '(0)' + assert tok.COREF_HEAD == None diff --git a/utils/coref_check.py b/utils/coref_check.py index 654bb13..4333b8e 100644 --- a/utils/coref_check.py +++ b/utils/coref_check.py @@ -79,7 +79,7 @@ def transform_tree(sentence): # transform each coordination # ^ ^ # | deprel | deprel - # tokenA => tokenX + # tokenA => tokenX # /conj \ conj /conj |conj \conj # tokenB tokenC tokenA tokenB tokenC for tokenA_id in coordinations: @@ -295,7 +295,7 @@ def inspect_dataset(dataset, stats): ref_count[refid] += 1 else: ref_count[refid] = 1 - + # build chain length statistics for refid in ref_count: stats.chain_lengths[ref_count[refid]] += 1 @@ -359,10 +359,10 @@ def inspect_dataset(dataset, stats): mentions = get_mentions(sentence) for token in sentence: - token.COREF = '_' + token.COREF_HEAD = '_' for mention in mentions: - sentence[mention['head']].COREF = mention['refid'] + sentence[mention['head']].COREF_HEAD = mention['refid'] logging.info('Number of mentions in sentence {}'.format(len(mentions))) mentions_in_doc += len(mentions) tokens_in_doc += len(sentence)