-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataloader_RE.py
54 lines (45 loc) · 1.7 KB
/
dataloader_RE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Dataset loader for the manually tagged dataset.
import numpy as np
import json
import os
from tqdm import tqdm
import ipdb
class JointDataset():
    """In-memory view of one split of the manually tagged dataset.

    The split is read from a JSON file under ``data_root``. The file is
    expected to hold a mapping from a document key (named ``pmcid`` here) to
    a list of units, where each unit is a dict with the keys ``'sent'``,
    ``'psgcy'`` and ``'psgid'``.
    NOTE(review): the meaning of ``psgcy`` / ``psgid`` is not visible in this
    file -- presumably parallel per-sentence annotation lists; confirm
    against the data documentation.

    Units are skipped when ``psgcy`` and ``psgid`` differ in length or when
    ``psgcy`` is empty. The kept units are exposed as four parallel lists:
    ``sents``, ``pmcids``, ``psgcy`` and ``psgid``.
    """

    # Maps each supported split name to the JSON file that stores it.
    _SPLIT_FILES = {
        'test': 'test.json',
        'train': 'train.json',
        'eval': 'dev.json',
        'test_arxiv': 'test_arxiv.json',
    }

    def __init__(self, data_root, split):
        """Load and filter the requested split.

        Args:
            data_root: Directory that contains the split JSON files.
            split: One of ``'train'``, ``'eval'``, ``'test'`` or
                ``'test_arxiv'``.

        Raises:
            ValueError: If ``split`` is not a supported split name.
        """
        try:
            self.tar_file = self._SPLIT_FILES[split]
        except KeyError:
            # Fail loudly here instead of with an AttributeError on
            # ``self.tar_file`` further down.
            raise ValueError(
                "Unknown split {!r}; expected one of {}".format(
                    split, sorted(self._SPLIT_FILES))
            ) from None

        # JSON text is UTF-8; do not depend on the platform default encoding.
        json_path = os.path.join(data_root, self.tar_file)
        with open(json_path, encoding='utf-8') as json_file:
            all_data = json.load(json_file)

        sents, pmcids, psgcy, psgid = [], [], [], []
        for pmcid, units in all_data.items():
            for unit in units:
                # Keep only units whose two annotation lists are non-empty
                # and aligned one-to-one.
                if len(unit['psgcy']) != len(unit['psgid']):
                    continue
                if len(unit['psgcy']) == 0:
                    continue
                pmcids.append(pmcid)
                sents.append(unit['sent'])
                psgcy.append(unit['psgcy'])
                psgid.append(unit['psgid'])

        self.sents = sents
        self.pmcids = pmcids
        self.psgcy = psgcy
        self.psgid = psgid
        print("Samples : {}".format(len(sents)))
# Eagerly load every split at import time so that other modules can import
# the ready-made dataset objects by name.
data_root = './Grant-RE'
train_dataset, eval_dataset, test_dataset, test_arxiv_dataset = (
    JointDataset(data_root, split_name)
    for split_name in ('train', 'eval', 'test', 'test_arxiv')
)