-
Notifications
You must be signed in to change notification settings - Fork 0
/
prompt_data.py
139 lines (132 loc) · 5.23 KB
/
prompt_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
Handles serialization of labeled entity pairs and saves result into JSON.
"""
from typing import Callable
import json
import random
import numpy as np
from pathlib import Path
from erllm import DATASET_FOLDER_PATH, DATASET_NAMES, PROMPT_DATA_FOLDER_PATH
from erllm.dataset.entity import OrderedEntity
from erllm.dataset.load_ds import load_dataset
def dataset_to_prompt_data(
    folder: Path,
    save_to: Path,
    to_str: Callable[[OrderedEntity], str] = lambda e: e.value_string(),
):
    """
    Serialize labeled entity pairs from a CSV dataset folder into prompt-ready JSON.

    Each pair becomes a record with keys "t" (ground truth), "id0"/"id1"
    (entity ids), and "e0"/"e1" (string-serialized entities). The result is
    written to '{dataset}.json' inside `save_to`, where the dataset name is
    the last path component of `folder`.

    Args:
        folder (Path): Path to the CSV dataset folder.
        save_to (Path): Directory the converted JSON file is written to.
        to_str (Callable[[OrderedEntity], str], optional): Serializer turning
            an OrderedEntity into the string used in the prompt.

    Example:
        dataset_to_prompt_data(Path("path/to/dataset"), Path("path/to/save"))
    """
    dataset_name = folder.parts[-1]
    labeled_pairs = load_dataset(folder, use_tqdm=True)
    records = [
        {
            "t": truth,
            "id0": left.id,
            "id1": right.id,
            "e0": to_str(left),
            "e1": to_str(right),
        }
        for truth, left, right in labeled_pairs
    ]
    with open(save_to / f"{dataset_name}.json", "w") as json_file:
        json.dump(records, json_file, indent=2)
# Named conversion configurations. Each entry supplies:
#   save_to:        output directory for the generated JSON files
#   dataset_paths:  dataset folders to convert
#   to_str:         OrderedEntity -> string serializer used in the prompt
#   seed (optional) RNG seed applied before conversion, for serializers
#                   that randomize attribute order/placement
CONFIGURATIONS = {
    "default": {
        "save_to": Path(PROMPT_DATA_FOLDER_PATH),
        "dataset_paths": [DATASET_FOLDER_PATH / dataset for dataset in DATASET_NAMES],
        "to_str": lambda e: e.value_string(),
    },
    "with-attr-names": {
        "save_to": Path(PROMPT_DATA_FOLDER_PATH) / "wattr_names",
        "dataset_paths": [DATASET_FOLDER_PATH / dataset for dataset in DATASET_NAMES],
        "to_str": lambda e: e.ffm_wrangle_string(),
    },
    "with-attr-names-rnd-order": {
        "save_to": Path(PROMPT_DATA_FOLDER_PATH) / "wattr_names_rnd_order",
        "dataset_paths": [
            DATASET_FOLDER_PATH / dataset
            for dataset in filter(lambda x: "dbpedia" not in x, DATASET_NAMES)
        ],
        "to_str": lambda e: e.ffm_wrangle_string(random_order=True),
    },
    "with-attr-names-embed-05": {
        "save_to": Path(PROMPT_DATA_FOLDER_PATH) / "wattr_names_embed_05",
        "dataset_paths": [
            DATASET_FOLDER_PATH / dataset
            for dataset in filter(lambda x: "dbpedia" not in x, DATASET_NAMES)
        ],
        "to_str": lambda e: e.embed_values_p(p_move=0.5, random_order=False),
        "seed": 123,
    },
    "with-attr-names-embed-one-ppair": {
        "save_to": Path(PROMPT_DATA_FOLDER_PATH) / "wattr_names_embed_one_ppair",
        "dataset_paths": [
            DATASET_FOLDER_PATH / dataset
            for dataset in filter(lambda x: "dbpedia" not in x, DATASET_NAMES)
        ],
        "to_str": lambda e: e.embed_values_k(k=1, random_order=False),
        "seed": 123,
    },
    "with-attr-names-embed-half": {
        "save_to": Path(PROMPT_DATA_FOLDER_PATH) / "wattr_names_embed_half",
        "dataset_paths": [
            DATASET_FOLDER_PATH / dataset
            for dataset in filter(lambda x: "dbpedia" not in x, DATASET_NAMES)
        ],
        "to_str": lambda e: e.embed_values_freq(freq=0.5, random_order=False),
        "seed": 123,
    },
    # BUG FIX: this entry was previously "disabled" by wrapping it in a
    # triple-quoted string literal. Inside a dict literal, that bare string
    # sat adjacent to the next key string and Python's implicit string-literal
    # concatenation silently fused them into one giant key — so the
    # "with-attr-names-misfield-half" entry was unreachable under its own name
    # and lookups of it raised KeyError. Kept disabled with real comments.
    # "with-attr-names-misfield-one": {
    #     "save_to": Path(PROMPT_DATA_FOLDER_PATH) / "wattr_names_misfield_one",
    #     "dataset_paths": [
    #         DATASET_FOLDER_PATH / dataset
    #         for dataset in filter(lambda x: "dbpedia" not in x, DATASET_NAMES)
    #     ],
    #     "to_str": lambda e: e.misfield_str(k=1, random_order=False),
    #     "seed": 123,
    # },
    "with-attr-names-misfield-half": {
        "save_to": Path(PROMPT_DATA_FOLDER_PATH) / "wattr_names_misfield_half",
        "dataset_paths": [
            DATASET_FOLDER_PATH / dataset
            for dataset in filter(lambda x: "dbpedia" not in x, DATASET_NAMES)
        ],
        "to_str": lambda e: e.misfield_str_freq(freq=0.5, random_order=False),
        "seed": 123,
    },
    "with-attr-names-misfield-all": {
        "save_to": Path(PROMPT_DATA_FOLDER_PATH) / "wattr_names_misfield_all",
        "dataset_paths": [
            DATASET_FOLDER_PATH / dataset
            for dataset in filter(lambda x: "dbpedia" not in x, DATASET_NAMES)
        ],
        "to_str": lambda e: e.misfield_str_freq(freq=1, random_order=False),
        "seed": 123,
    },
    "full-dbpedia": {
        "save_to": Path(PROMPT_DATA_FOLDER_PATH) / "full_dbpedia",
        "dataset_paths": [DATASET_FOLDER_PATH / "dbpedia10k"],
        "to_str": lambda e: e.value_string(),
        "seed": 123,
    },
}
if __name__ == "__main__":
    # Only configurations whose name contains "full-dbpedia" are run here;
    # adjust the substring filter to regenerate other prompt-data variants.
    for config_name, config in CONFIGURATIONS.items():
        if "full-dbpedia" not in config_name:
            continue
        output_dir = config["save_to"]
        output_dir.mkdir(parents=True, exist_ok=True)
        # Seed both RNGs so serializers that randomize order are reproducible.
        if "seed" in config:
            random.seed(config["seed"])
            np.random.seed(config["seed"])
        for dataset_folder in config["dataset_paths"]:
            dataset_to_prompt_data(dataset_folder, output_dir, config["to_str"])