add: chuliu_edmonds_one_root_with_constrains
kirianguiller committed Apr 5, 2023
1 parent 80918f1 commit f1f14d3
Showing 3 changed files with 45 additions and 7 deletions.
18 changes: 15 additions & 3 deletions BertForDeprel/parser/utils/chuliu_edmonds_utils.py
@@ -1,4 +1,6 @@
+from typing import List, Tuple
import numpy as np
+from numpy.typing import NDArray

def tarjan(tree):
""""""
@@ -45,7 +47,6 @@ def strong_connect(i):

def chuliu_edmonds(scores):
""""""

np.fill_diagonal(scores, -float('inf')) # prevent self-loops
scores[0] = -float('inf')
scores[0,0] = 0
@@ -126,7 +127,7 @@ def chuliu_edmonds(scores):
return new_tree

#===============================================================
-def chuliu_edmonds_one_root(scores):
+def chuliu_edmonds_one_root(scores: NDArray):
""""""

scores = scores.astype(np.float64)
@@ -161,4 +162,15 @@ def set_root(scores, root):
f.write('{}: {}, {}\n'.format(tree, scores, roots_to_try))
f.write('{}: {}, {}, {}\n'.format(_tree, _scores, tree_probs, tree_score))
raise
-    return best_tree
+    return best_tree
+
+
+def chuliu_edmonds_one_root_with_constrains(scores: NDArray, forced_relations: List[Tuple] = []):
+    """
+    forced_relations: list of (i, j) tuples; the i-th token is forced to take the j-th token as its head.
+    """
+    if len(forced_relations):
+        scores = scores.copy()
+        for forced_relation in forced_relations:
+            scores[forced_relation[0], forced_relation[1]] += 1000
+    return chuliu_edmonds_one_root(scores)
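
For context, a minimal usage sketch of the new helper (an illustration, not part of this diff), reusing the mock score matrix and the expected trees from the new test file below. It assumes the decoder's convention that scores[i, j] is the score of token j being the head of token i, with index 0 reserved for the artificial ROOT, as implied by the forced_relations docstring and the tests:

import numpy as np
from BertForDeprel.parser.utils.chuliu_edmonds_utils import (
    chuliu_edmonds_one_root,
    chuliu_edmonds_one_root_with_constrains,
)

# 6x6 head-score matrix (ROOT + 5 tokens): scores[i, j] scores "j is the head of i".
scores = np.array([
    [-9.9, -6.6, -1.2, -6.5, -8.7, -8.3],
    [-14.9, -13.2, 0.5, -13.0, -9.4, -7.9],
    [1.6, -3.8, -7.4, -3.9, -14.9, -20.1],
    [-1.8, -4.3, -0.9, -2.6, -6.1, -7.7],
    [-9.7, -8.5, -8.7, -1.0, -5.7, -8.8],
    [-7.9, -4.9, -4.5, -1.2, -3.0, -4.1],
], dtype=np.float64)

# Unconstrained decoding: heads[i] is the predicted head of token i.
print(chuliu_edmonds_one_root(scores).tolist())                            # [0, 2, 0, 2, 3, 3]

# Force token 4 to take token 5 as its head: the (4, 5) cell is bumped by
# +1000, so the constrained edge outweighs every unconstrained alternative.
print(chuliu_edmonds_one_root_with_constrains(scores, [(4, 5)]).tolist())  # [0, 2, 0, 2, 5, 3]
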
26 changes: 26 additions & 0 deletions tests/chuliu_edmonds_utils_test.py
@@ -0,0 +1,26 @@
+from BertForDeprel.parser.utils.chuliu_edmonds_utils import chuliu_edmonds, chuliu_edmonds_one_root, chuliu_edmonds_one_root_with_constrains
+import numpy as np
+
+
+mock_coefs_lists = [
+    [-9.9, -6.6, -1.2, -6.5, -8.7, -8.3],
+    [-14.9, -13.2, 0.5, -13, -9.4, -7.9],
+    [1.6, -3.8, -7.4, -3.9, -14.9, -20.1],
+    [-1.8, -4.3, -0.9, -2.6, -6.1, -7.7],
+    [-9.7, -8.5, -8.7, -1.0, -5.7, -8.8],
+    [-7.9, -4.9, -4.5, -1.2, -3.0, -4.1]
+]
+
+mock_coefs_array = np.array(mock_coefs_lists).astype(np.float64)
+
+def test_chuliu_edmonds_one_root():
+    assert chuliu_edmonds_one_root(mock_coefs_array).tolist() == [0, 2, 0, 2, 3, 3]
+
+def test_chuliu_edmonds_one_root_with_constrains():
+    assert chuliu_edmonds_one_root_with_constrains(mock_coefs_array).tolist() == [0, 2, 0, 2, 3, 3]
+    assert chuliu_edmonds_one_root_with_constrains(mock_coefs_array, []).tolist() == [0, 2, 0, 2, 3, 3]
+    assert chuliu_edmonds_one_root_with_constrains(mock_coefs_array, [(4, 5)]).tolist() == [0, 2, 0, 2, 5, 3]
+    assert chuliu_edmonds_one_root_with_constrains(mock_coefs_array, [(1, 0)]).tolist() == [0, 0, 1, 2, 3, 3]
+    assert chuliu_edmonds_one_root_with_constrains(mock_coefs_array, [(3, 0)]).tolist() == [0, 2, 3, 0, 3, 3]
+    assert chuliu_edmonds_one_root_with_constrains(mock_coefs_array, [(4, 1), (4, 5)]).tolist() == [0, 2, 0, 2, 1, 3]
+    assert chuliu_edmonds_one_root_with_constrains(mock_coefs_array, [(5, 1), (5, 4)]).tolist() == [0, 2, 0, 2, 3, 4]
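
A note on the last two assertions (an aside, not part of this diff): when several forced relations share the same dependent, every forced cell gets the same +1000 bump, so the original scores break the tie before the usual single-root and cycle handling in chuliu_edmonds_one_root applies. A quick check against the mock matrix:

# In the [(4, 1), (4, 5)] case, head 1 wins because its underlying score is higher.
row_4 = mock_coefs_lists[4]               # [-9.7, -8.5, -8.7, -1.0, -5.7, -8.8]
assert row_4[1] + 1000 > row_4[5] + 1000  # 991.5 > 991.2, so heads[4] == 1
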
8 changes: 4 additions & 4 deletions tests/load_data_utils_test.py
@@ -58,16 +58,16 @@ def test_train_output():
assert dataset[0]["idx"] == 0
assert dataset[0]["uposs"] == [-1, 6, 2, 3, 0, 4, 1, 7, 1, 7, -1, 1, 8, 3, 4]
assert dataset[0]["heads"] == [-1, 2, 0, 5, 5, 2, 5, 6, 5, 8, -1, 5, 11, 14, 12]
assert dataset[0]["deprels"] == [-1, 1, 8, 3, 5, 9, 10, 0, 10, 0, -1, 6, 0, 3, 0]
assert dataset[0]["deprels"] == [-1, 2, 8, 4, 6, 9, 10, 1, 10, 1, -1, 7, 1, 4, 1]


def test_collate_fn():
dataset = ConlluDataset(PATH_TEST_CONLLU, model_params_test, "train", compute_annotation_schema_if_not_found=True)
batch = dataset.collate_fn([dataset[0], dataset[1]])
assert torch.equal(batch["deprels"], torch.tensor([[-1, 1, 8, 3, 5, 9, 10, 0, 10, 0, -1, 6, 0, 3, 0, -1],
[-1, 1, 8, 2, 9, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]]))
assert torch.equal(batch["deprels"], torch.tensor([[-1, 2, 8, 4, 6, 9, 10, 1, 10, 1, -1, 7, 1, 4, 1, -1],
[-1, 2, 8, 3, 9, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]]))

def test_add_prediction_to_sentence_json():
dataset = ConlluDataset(PATH_TEST_CONLLU, model_params_test, "train", compute_annotation_schema_if_not_found=True)
-predicted_sentence_json = dataset.add_prediction_to_sentence_json(0, [2,3,4,2,5], [1,2,4,15,4], [5,2,3,4,3])
+predicted_sentence_json = dataset.add_prediction_to_sentence_json(0, [2,3,4,2,5], [0,0,0,0,0], [1,2,4,15,4], [5,2,3,4,3], [2,3,4,2,5], [5,2,3,4,3])
assert predicted_sentence_json["treeJson"]["nodesJson"]["4"]["HEAD"] == 15
