From 75885f3c73bb8e38279923bb7122f44671c9c192 Mon Sep 17 00:00:00 2001 From: kirianguiller Date: Wed, 5 Apr 2023 16:06:12 +0000 Subject: [PATCH] add: chuliu_edmonds_one_root_with_constrains --- .../parser/utils/chuliu_edmonds_utils.py | 18 ++++++++++--- tests/chuliu_edmonds_utils_test.py | 26 +++++++++++++++++++ tests/load_data_utils_test.py | 8 +++--- 3 files changed, 45 insertions(+), 7 deletions(-) create mode 100644 tests/chuliu_edmonds_utils_test.py diff --git a/BertForDeprel/parser/utils/chuliu_edmonds_utils.py b/BertForDeprel/parser/utils/chuliu_edmonds_utils.py index f63f9cd..3effdb6 100755 --- a/BertForDeprel/parser/utils/chuliu_edmonds_utils.py +++ b/BertForDeprel/parser/utils/chuliu_edmonds_utils.py @@ -1,4 +1,6 @@ +from typing import List, Tuple import numpy as np +from numpy.typing import NDArray def tarjan(tree): """""" @@ -45,7 +47,6 @@ def strong_connect(i): def chuliu_edmonds(scores): """""" - np.fill_diagonal(scores, -float('inf')) # prevent self-loops scores[0] = -float('inf') scores[0,0] = 0 @@ -126,7 +127,7 @@ def chuliu_edmonds(scores): return new_tree #=============================================================== -def chuliu_edmonds_one_root(scores): +def chuliu_edmonds_one_root(scores: NDArray): """""" scores = scores.astype(np.float64) @@ -161,4 +162,15 @@ def set_root(scores, root): f.write('{}: {}, {}\n'.format(tree, scores, roots_to_try)) f.write('{}: {}, {}, {}\n'.format(_tree, _scores, tree_probs, tree_score)) raise - return best_tree \ No newline at end of file + return best_tree + + +def chuliu_edmonds_one_root_with_constrains(scores: NDArray, forced_relations: List[Tuple] = []): + """ + forced_relations: List[Tuple] : List of (i, j) tuples, the i-th token will be forced to be a dependent of the j-th token + """ + if len(forced_relations): + scores = scores.copy() + for forced_relation in forced_relations: + scores[forced_relation[0], forced_relation[1]] += 1000 + return chuliu_edmonds_one_root(scores) diff --git
a/tests/chuliu_edmonds_utils_test.py b/tests/chuliu_edmonds_utils_test.py new file mode 100644 index 0000000..9859c8a --- /dev/null +++ b/tests/chuliu_edmonds_utils_test.py @@ -0,0 +1,26 @@ +from BertForDeprel.parser.utils.chuliu_edmonds_utils import chuliu_edmonds, chuliu_edmonds_one_root, chuliu_edmonds_one_root_with_constrains +import numpy as np + + +mock_coefs_lists = [ + [-9.9, -6.6, -1.2, -6.5, -8.7, -8.3], + [-14.9, -13.2, 0.5, -13, -9.4, -7.9], + [1.6, -3.8, -7.4, -3.9, -14.9, -20.1], + [-1.8, -4.3, -0.9, -2.6, -6.1, -7.7], + [-9.7, -8.5, -8.7, -1.0, -5.7, -8.8], + [-7.9, -4.9, -4.5, -1.2, -3.0, -4.1] + ] + +mock_coefs_array = np.array(mock_coefs_lists).astype(np.float64) + +def test_chuliu_edmonds_one_root(): + assert chuliu_edmonds_one_root(mock_coefs_array).tolist() == [0, 2, 0, 2, 3, 3] + +def test_chuliu_edmonds_one_root_with_constrains(): + assert chuliu_edmonds_one_root_with_constrains(mock_coefs_array).tolist() == [0, 2, 0, 2, 3, 3] + assert chuliu_edmonds_one_root_with_constrains(mock_coefs_array, []).tolist() == [0, 2, 0, 2, 3, 3] + assert chuliu_edmonds_one_root_with_constrains(mock_coefs_array, [(4, 5)]).tolist() == [0, 2, 0, 2, 5, 3] + assert chuliu_edmonds_one_root_with_constrains(mock_coefs_array, [(1, 0)]).tolist() == [0, 0, 1, 2, 3, 3] + assert chuliu_edmonds_one_root_with_constrains(mock_coefs_array, [(3, 0)]).tolist() == [0, 2, 3, 0, 3, 3] + assert chuliu_edmonds_one_root_with_constrains(mock_coefs_array, [(4, 1), (4, 5)]).tolist() == [0, 2, 0, 2, 1, 3] + assert chuliu_edmonds_one_root_with_constrains(mock_coefs_array, [(5, 1), (5, 4)]).tolist() == [0, 2, 0, 2, 3, 4] \ No newline at end of file diff --git a/tests/load_data_utils_test.py b/tests/load_data_utils_test.py index 8e59650..791b067 100644 --- a/tests/load_data_utils_test.py +++ b/tests/load_data_utils_test.py @@ -58,16 +58,16 @@ def test_train_output(): assert dataset[0]["idx"] == 0 assert dataset[0]["uposs"] == [-1, 6, 2, 3, 0, 4, 1, 7, 1, 7, -1, 1, 8, 3, 4] assert
dataset[0]["heads"] == [-1, 2, 0, 5, 5, 2, 5, 6, 5, 8, -1, 5, 11, 14, 12] - assert dataset[0]["deprels"] == [-1, 1, 8, 3, 5, 9, 10, 0, 10, 0, -1, 6, 0, 3, 0] + assert dataset[0]["deprels"] == [-1, 2, 8, 4, 6, 9, 10, 1, 10, 1, -1, 7, 1, 4, 1] def test_collate_fn(): dataset = ConlluDataset(PATH_TEST_CONLLU, model_params_test, "train", compute_annotation_schema_if_not_found=True) batch = dataset.collate_fn([dataset[0], dataset[1]]) - assert torch.equal(batch["deprels"], torch.tensor([[-1, 1, 8, 3, 5, 9, 10, 0, 10, 0, -1, 6, 0, 3, 0, -1], - [-1, 1, 8, 2, 9, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]])) + assert torch.equal(batch["deprels"], torch.tensor([[-1, 2, 8, 4, 6, 9, 10, 1, 10, 1, -1, 7, 1, 4, 1, -1], + [-1, 2, 8, 3, 9, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]])) def test_add_prediction_to_sentence_json(): dataset = ConlluDataset(PATH_TEST_CONLLU, model_params_test, "train", compute_annotation_schema_if_not_found=True) - predicted_sentence_json = dataset.add_prediction_to_sentence_json(0, [2,3,4,2,5], [1,2,4,15,4], [5,2,3,4,3]) + predicted_sentence_json = dataset.add_prediction_to_sentence_json(0, [2,3,4,2,5], [0,0,0,0,0], [1,2,4,15,4], [5,2,3,4,3], [2,3,4,2,5], [5,2,3,4,3]) assert predicted_sentence_json["treeJson"]["nodesJson"]["4"]["HEAD"] == 15 \ No newline at end of file