More consistency in file formats #3

Draft · wants to merge 2 commits into master
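
Summary, as inferred from the diff: the two commits split the coreference annotation on Token into two attributes. COREF keeps the raw bra-ket spans read from the extra 13th column of a .conllu file, while the new COREF_HEAD carries the head-based refids produced by preprocessing. All consumers (run_coref.py, run_entity.py, run_mentions.py, stroll/coref.py, stroll/graph.py, utils/coref_check.py) now read and write COREF_HEAD, and a sample file plus unit tests for stroll.conllu are added.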
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,2 +1,4 @@
 **/__pycache__
 models/
+*.egg-info/
+.ipynb_checkpoints/
2 changes: 2 additions & 0 deletions requirements.txt
@@ -12,3 +12,5 @@ seaborn
 stanza
 scorch
 Jinja2
+numba
+tensorboard
8 changes: 4 additions & 4 deletions run_coref.py
@@ -78,8 +78,8 @@ def write_html(dataset, name):
             'sentences': [sentence]
         }
         for token in sentence:
-            if token.COREF != '_':
-                entities[token.COREF] = 1
+            if token.COREF_HEAD != '_':
+                entities[token.COREF_HEAD] = 1

     with open(name, 'w') as f:
         f.write(template.render(
@@ -167,12 +167,12 @@ def main(args):
         # clear old info
         for s in dataset:
             for t in s:
-                t.COREF = '_'
+                t.COREF_HEAD = '_'

         sent_index = test_graph.ndata['sent_index'][mention_idxs]
         token_index = test_graph.ndata['token_index'][mention_idxs]
         for s, t, m in zip(sent_index, token_index, system_clusters):
-            dataset[s][t].COREF = '{:d}'.format(int(m))
+            dataset[s][t].COREF_HEAD = '{:d}'.format(int(m))

         if args.score:
             # score the clustering
2 changes: 1 addition & 1 deletion run_entity.py
@@ -196,7 +196,7 @@ def eval(net, doc):
             if args.verbose:
                 print(mention)
             mention.refid = refid
-            mention.sentence[mention.head].COREF = '{}'.format(refid)
+            mention.sentence[mention.head].COREF_HEAD = '{}'.format(refid)
     if args.verbose:
         print('= - = - = - = - = - = - =')
6 changes: 3 additions & 3 deletions run_mentions.py
@@ -170,15 +170,15 @@ def write_output_conll2012(dataset, filename):
             zip(sent_indices, token_indices, system):
         if isMention:
             if token_index in refid_lookup[sent_index]:
-                dataset[sent_index][token_index].COREF = \
+                dataset[sent_index][token_index].COREF_HEAD = \
                     refid_lookup[sent_index][token_index]
             else:
                 # treat every mention as a new entity
-                dataset[sent_index][token_index].COREF = \
+                dataset[sent_index][token_index].COREF_HEAD = \
                     '{}'.format(entity)
                 entity += 1
         else:
-            dataset[sent_index][token_index].COREF = '_'
+            dataset[sent_index][token_index].COREF_HEAD = '_'

     if args.score:
         # correct mentions:
14 changes: 11 additions & 3 deletions stroll/conllu.py
@@ -18,7 +18,7 @@

 class Token():
     """A class representing a single token, ie. a word, with its annotation."""
-    def __init__(self, fields, isEncoded=False):
+    def __init__(self, fields, isEncoded=False, is_preprocessed=False):
         self.isEncoded = isEncoded
         if len(fields) < 10:
             logging.warn(
@@ -51,10 +51,18 @@ def __init__(self, fields, isEncoded=False):

         # Treat field 12 as co-reference info
         # NOTE: this is a private extension to the conllu format
+        self.COREF = None
+        self.COREF_HEAD = None
         if len(fields) >= 13:
-            self.COREF = fields[12]
+            if is_preprocessed:
+                self.COREF_HEAD = fields[12]
+            else:
+                self.COREF = fields[12]
         else:
-            self.COREF = '_'
+            if is_preprocessed:
+                self.COREF_HEAD = '_'
+            else:
+                self.COREF = '_'

         # For coreference resolution
         if len(fields) >= 14:
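
A minimal sketch of the new constructor flag, based only on the hunk above; the field values are copied from the first token of test/data/test.conllu below, and Token's handling of the standard columns is assumed to match the tests:

# Sketch, not part of the diff: how is_preprocessed routes the 13th column.
from stroll.conllu import Token

fields = ['1', 'Jip', 'Jip', 'PROPN', 'N|eigen|ev|basis|zijd|stan',
          'Gender=Com|Number=Sing', '2', 'nsubj', '_', '_', '_', '_', '(0)']

raw = Token(fields)
# Raw files keep the bra-ket span in COREF; COREF_HEAD stays unset.
assert raw.COREF == '(0)' and raw.COREF_HEAD is None

pre = Token(fields, is_preprocessed=True)
# Preprocessed files carry head-based refids in COREF_HEAD instead.
assert pre.COREF_HEAD == '(0)' and pre.COREF is None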
22 changes: 11 additions & 11 deletions stroll/coref.py
@@ -144,7 +144,7 @@ def nested(self):
         while token.HEAD not in visited:
             token = sentence[token.HEAD]
             visited.append(token.ID)
-            if token.COREF != '_':
+            if token.COREF_HEAD != '_':
                 return 1.0

         return 0.0
@@ -355,7 +355,7 @@ def build_mentions_from_heads(sentence, heads):
         mentions[(head, head)] = Mention(
             head=head,
             sentence=sentence,
-            refid=sentence[head].COREF,
+            refid=sentence[head].COREF_HEAD,
             start=head,
             end=head,
             ids=[head],
@@ -410,7 +410,7 @@ def build_mentions_from_heads(sentence, heads):
         mentions[(start, end)] = Mention(
             head=head,
             sentence=sentence,
-            refid=sentence[head].COREF,
+            refid=sentence[head].COREF_HEAD,
             start=sentence[id_start].ID,
             end=sentence[id_end].ID,
             ids=[sentence[i].ID for i in pruned_ids],
@@ -429,7 +429,7 @@ def get_mentions(sentence):
     """
     heads = []
     for token in sentence:
-        if token.COREF != '_':
+        if token.COREF_HEAD != '_':
             heads.append(token.ID)

     return build_mentions_from_heads(sentence, heads)
@@ -447,13 +447,13 @@ def mark_gold_anaphores(dataset):
         doc_rank = sentence.doc_rank
         entities = {}
         for token in sentence:
-            if token.COREF == '_':
+            if token.COREF_HEAD == '_':
                 continue
-            if token.COREF in entities:
-                entities[token.COREF] += 1
+            if token.COREF_HEAD in entities:
+                entities[token.COREF_HEAD] += 1
                 token.anaphore = 1.0
             else:
-                entities[token.COREF] = 1
+                entities[token.COREF_HEAD] = 1
                 token.anaphore = 0.0


@@ -665,13 +665,13 @@ def preprocess_sentence(sentence):
     bra_ket_mentions = get_mentions_from_bra_ket(sentence)
     head_mentions = convert_mentions(bra_ket_mentions)

-    # clear bra-ket annotations
+    # clear annotations if present
     for token in sentence:
-        token.COREF = '_'
+        token.COREF_HEAD = '_'

     # add head based annotations
     for mention in head_mentions:
-        sentence[mention.head].COREF = mention.refid
+        sentence[mention.head].COREF_HEAD = mention.refid

     return bra_ket_mentions, head_mentions
2 changes: 1 addition & 1 deletion stroll/graph.py
@@ -68,7 +68,7 @@ def __getitem__(self, index):
                 0).view(1, -1),
             'frame': token.FRAME,
             'role': token.ROLE,
-            'coref': token.COREF,
+            'coref': token.COREF_HEAD,
             'sent_index': torch.tensor([index], dtype=torch.int32),
             'token_index': torch.tensor(
                 [sentence.index(token.ID)],
14 changes: 14 additions & 0 deletions test/data/test.conllu
@@ -0,0 +1,14 @@
+# newdoc id = jip-en-janneke
+# sent_id = 1
+# text = Jip liep in de tuin en hij verveelde zich zo .
+1 Jip Jip PROPN N|eigen|ev|basis|zijd|stan Gender=Com|Number=Sing 2 nsubj _ _ _ _ (0)
+2 liep lopen VERB WW|pv|verl|ev Number=Sing|Tense=Past|VerbForm=Fin 0 root _ _ _ _ -
+3 in in ADP VZ|init _ 5 case _ _ _ _ -
+4 de de DET LID|bep|stan|rest Definite=Def 5 det _ _ _ _ (1
+5 tuin tuin NOUN N|soort|ev|basis|zijd|stan Gender=Com|Number=Sing 2 obl _ _ _ _ 1)
+6 en en CCONJ VG|neven _ 8 cc _ _ _ _ -
+7 hij hij PRON VNW|pers|pron|nomin|vol|3|ev|masc Case=Nom|Person=3|PronType=Prs 8 nsubj _ _ _ _ (0)
+8 verveelde vervelen VERB WW|pv|verl|ev Number=Sing|Tense=Past|VerbForm=Fin 2 conj _ _ _ _ -
+9 zich zich PRON VNW|refl|pron|obl|red|3|getal Case=Acc|Person=3|PronType=Prs|Reflex=Yes 8 expl:pv _ _ _ _ (2)
+10 zo zo ADV BW _ 8 advmod _ _ _ _ -
+11 . . PUNCT LET _ 2 punct _ _ _ _ -
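
For reference: the last column is the bra-ket coreference annotation, the private conllu extension noted in stroll/conllu.py. '(0)' opens and closes entity 0 on a single token (Jip and, later, hij), '(1' and '1)' bracket the multi-token mention 'de tuin', '(2)' marks 'zich' as entity 2, and '-' marks tokens outside any mention.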
20 changes: 20 additions & 0 deletions test/test_conllu.py
@@ -0,0 +1,20 @@
+import stroll.conllu
+import os
+
+__here__ = os.path.dirname(os.path.realpath(__file__))
+
+
+def test_empty_conll():
+    dataset = stroll.conllu.ConlluDataset()
+    assert len(dataset.sentences) == 0
+
+
+def test_load_conllu():
+    input_file = os.path.join(__here__, 'data', 'test.conllu')
+    dataset = stroll.conllu.ConlluDataset(input_file)
+    assert len(dataset.sentences) == 1
+    sent = dataset[0]
+    assert len(sent) == 11
+    tok = sent[0]
+    assert tok.COREF == '(0)'
+    assert tok.COREF_HEAD == None
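
The tests above only exercise loading. As a sketch of how the two attributes are meant to interact, assuming preprocess_sentence needs nothing beyond the sentence, as in its signature in stroll/coref.py:

# Sketch, assuming these import paths match the repository layout.
from stroll.conllu import ConlluDataset
from stroll.coref import preprocess_sentence

dataset = ConlluDataset('test/data/test.conllu')
for sentence in dataset:
    # convert bra-ket mentions (COREF) to head-based mentions, then
    # write each mention's refid to its head token's COREF_HEAD
    bra_ket_mentions, head_mentions = preprocess_sentence(sentence)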
8 changes: 4 additions & 4 deletions utils/coref_check.py
@@ -79,7 +79,7 @@ def transform_tree(sentence):
     # transform each coordination
     #    ^                ^
     #    | deprel         | deprel
-    # tokenA      =>   tokenX
+    # tokenA      =>    tokenX
     #   /conj \ conj    /conj |conj \conj
     # tokenB tokenC    tokenA tokenB tokenC
     for tokenA_id in coordinations:
@@ -295,7 +295,7 @@ def inspect_dataset(dataset, stats):
                 ref_count[refid] += 1
             else:
                 ref_count[refid] = 1
- 
+
     # build chain length statistics
     for refid in ref_count:
         stats.chain_lengths[ref_count[refid]] += 1
@@ -359,10 +359,10 @@ def inspect_dataset(dataset, stats):
         mentions = get_mentions(sentence)

         for token in sentence:
-            token.COREF = '_'
+            token.COREF_HEAD = '_'

         for mention in mentions:
-            sentence[mention['head']].COREF = mention['refid']
+            sentence[mention['head']].COREF_HEAD = mention['refid']
         logging.info('Number of mentions in sentence {}'.format(len(mentions)))
         mentions_in_doc += len(mentions)
         tokens_in_doc += len(sentence)