-
Notifications
You must be signed in to change notification settings - Fork 0
/
to_ditto.py
125 lines (107 loc) · 4.35 KB
/
to_ditto.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
Provides functions for converting labeled pairs of entities to Ditto format and splitting them into train, validation, and test sets.
"""
from pathlib import Path
from typing import Iterable
from typing import List, Iterable, Tuple, Set
from pathlib import Path
import random
from erllm.dataset.entity import Entity, OrderedEntity
# Either kind of entity may appear on each side of a pair.
_EntityLike = Entity | OrderedEntity
# One labeled pair: (label, left entity, right entity); the label is an int or bool.
_LabeledPair = Tuple[int | bool, _EntityLike, _EntityLike]
# Any iterable of labeled pairs (list, set, generator, ...).
LabeledPairs = Iterable[_LabeledPair]
# Labeled pairs held in a set, as taken and returned by ``sample``.
LabeledPairsSet = Set[_LabeledPair]
def to_ditto_task(
    train: LabeledPairs,
    valid: LabeledPairs,
    test: LabeledPairs,
    task_folder: Path,
) -> None:
    """
    Write the train, validation, and test pairs of a Ditto task as text files.

    Creates ``task_folder`` (including missing parents) if needed and writes
    ``train.txt``, ``valid.txt`` and ``test.txt`` into it via ``pairs_to_ditto``.

    Args:
        train (LabeledPairs): Labeled pairs for training.
        valid (LabeledPairs): Labeled pairs for validation.
        test (LabeledPairs): Labeled pairs for testing.
        task_folder (Path): Folder in which the Ditto files are written.

    Returns:
        None
    """
    task_folder.mkdir(parents=True, exist_ok=True)
    # Each split is a single collection of (label, entity0, entity1) pairs,
    # written to "<stem>.txt" inside the task folder.
    for labeled_pairs, stem in zip((train, valid, test), ("train", "valid", "test")):
        ditto_file = task_folder / f"{stem}.txt"
        pairs_to_ditto(labeled_pairs, ditto_file)
def pairs_to_ditto(
    labeled_pairs: LabeledPairs,
    ditto_file: Path,
) -> None:
    """
    Write labeled entity pairs to a file in Ditto's line format.

    Each pair becomes one tab-separated line: the left entity's
    ``to_ditto_str()``, the right entity's ``to_ditto_str()``, and the label
    as ``int(label)``. The file is opened in write mode, so existing content
    is replaced.

    Args:
        labeled_pairs (LabeledPairs): Iterable of ``(label, entity0, entity1)``
            tuples; ``label`` is an int or bool.
        ditto_file (Path): Path of the Ditto file to create.

    Returns:
        None
    """
    with open(ditto_file, "w", encoding="utf-8") as file:
        # NOTE(review): the tab/newline separators assume to_ditto_str() never
        # emits those characters itself — confirm in the Entity implementation.
        file.writelines(
            f"{entity0.to_ditto_str()}\t{entity1.to_ditto_str()}\t{int(label)}\n"
            for label, entity0, entity1 in labeled_pairs
        )
def ditto_split(
    labeled_pairs: LabeledPairs,
    train_fraction: float,
    valid_fraction: float,
    seed: int,
) -> Tuple[LabeledPairs, LabeledPairs, LabeledPairs]:
    """
    Split the labeled pairs into train, validation, and remaining pairs.

    The input is converted to a set, so duplicate pairs are collapsed. Train
    and validation pairs are drawn with ``sample`` using the positive ratio of
    the whole input. The remaining pairs are whatever is left over.

    Args:
        labeled_pairs (LabeledPairs): The labeled pairs to split.
        train_fraction (float): Fraction of pairs to use for training.
        valid_fraction (float): Fraction of pairs to use for validation.
        seed (int): Seed passed to each ``sample`` call.

    Returns:
        Tuple[LabeledPairs, LabeledPairs, LabeledPairs]: Disjoint sets holding
        the train, validation, and remaining pairs. All three are empty when
        ``labeled_pairs`` is empty.

    Raises:
        ValueError: If a fraction is negative, or if the two fractions
            together (rounded to whole pairs) ask for more pairs than exist.
    """
    remaining = set(labeled_pairs)
    N = len(remaining)
    if N == 0:
        # Nothing to split; also avoids dividing by zero for the ratio below.
        return set(), set(), set()
    pos_ratio = sum(label for label, _, _ in remaining) / N
    # Round the combined size once and derive the validation size from it, so
    # the train and validation sizes add up to exactly the rounded total.
    N_train_valid = int(round(N * (train_fraction + valid_fraction)))
    N_train = int(round(N * train_fraction))
    N_valid = N_train_valid - N_train
    if N_train < 0 or N_valid < 0 or N_train_valid > N:
        raise ValueError(
            f"Invalid split fractions: train_fraction={train_fraction}, "
            f"valid_fraction={valid_fraction}; both must be non-negative and "
            f"their sum at most 1."
        )
    train = sample(remaining, pos_ratio, N_train, seed)
    remaining.difference_update(train)
    valid = sample(remaining, pos_ratio, N_valid, seed)
    remaining.difference_update(valid)
    # Sanity checks: the three splits partition the de-duplicated input.
    assert len(train) + len(valid) + len(remaining) == N
    assert train.intersection(valid) == set()
    assert train.intersection(remaining) == set()
    assert valid.intersection(remaining) == set()
    return train, valid, remaining
def sample(
    labeled_pairs: LabeledPairsSet, pos_ratio: float, N: int, seed: int
) -> LabeledPairsSet:
    """
    Randomly sample N labeled pairs, aiming for the given ratio of positives.

    Pairs whose label equals 1 count as positive and pairs whose label equals 0
    count as negative. Pairs with any other label are never sampled.
    ``round(N * pos_ratio)`` positives are requested. If fewer positives or
    negatives are available than requested, the shortfall is taken from the
    other class, so exactly N pairs are returned.

    A private ``random.Random(seed)`` produces the same draws as seeding the
    global RNG would, but leaves the module-level ``random`` state untouched.

    Args:
        labeled_pairs (LabeledPairsSet): The set of labeled pairs to draw from.
        pos_ratio (float): Desired fraction of positive pairs in the sample.
        N (int): Total number of pairs to sample.
        seed (int): Seed for the random number generator.

    Returns:
        LabeledPairsSet: The sampled subset of labeled pairs.

    Raises:
        ValueError: If N exceeds the number of pairs labeled 0 or 1.
    """
    rng = random.Random(seed)
    # NOTE(review): these lists follow set iteration order, which depends on
    # element hashes. If Entity hashes derive from str values, hash
    # randomization (PYTHONHASHSEED) can change the order, and therefore the
    # sample, between interpreter runs despite the fixed seed. TODO confirm.
    pos = [pair for pair in labeled_pairs if pair[0] == 1]
    neg = [pair for pair in labeled_pairs if pair[0] == 0]
    if N > len(pos) + len(neg):
        raise ValueError(
            f"Cannot sample {N} pairs from {len(pos) + len(neg)} pairs labeled 0 or 1."
        )
    # Rounding can request more pairs of one class than remain (e.g. for the
    # validation split after training pairs were removed); move the excess to
    # the other class instead of failing inside random.sample.
    N_pos = min(int(round(N * pos_ratio)), len(pos))
    N_neg = N - N_pos
    if N_neg > len(neg):
        N_neg = len(neg)
        N_pos = N - N_neg
    pos_sample = set(rng.sample(pos, N_pos))
    neg_sample = set(rng.sample(neg, N_neg))
    return pos_sample.union(neg_sample)