# load_data.py
import pickle
import os
import ast

import pandas as pd
import torch
import numpy as np

from entity_marker import add_entity_mark


class RE_Dataset(torch.utils.data.Dataset):
    """Dataset class for relation extraction: wraps tokenized inputs and their labels."""

    def __init__(self, pair_dataset, labels):
        self.pair_dataset = pair_dataset
        self.labels = labels

    def __getitem__(self, idx):
        # Return one example as a dict of input tensors plus its label.
        item = {key: val[idx].clone().detach() for key, val in self.pair_dataset.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)
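
# A minimal usage sketch (assumptions: `tokenized` comes from tokenized_dataset below,
# `train_labels` is a sequence of integer class ids, and the exact tensor keys depend on
# the tokenizer in use):
#
#     tokenized = tokenized_dataset(train_df, tokenizer)
#     train_dataset = RE_Dataset(tokenized, train_labels)
#     train_dataset[0]  # -> e.g. {'input_ids': ..., 'attention_mask': ..., 'labels': tensor(k)}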


def preprocessing_dataset(dataset):
    """Convert the raw csv DataFrame into the DataFrame format used downstream."""
    subject_entity = []
    object_entity = []
    for i, j in zip(dataset['subject_entity'], dataset['object_entity']):
        # Each entity column holds a stringified dict; parse it and keep only the entity word.
        # ast.literal_eval is used instead of eval so no arbitrary expression is executed.
        i = ast.literal_eval(i)['word']
        j = ast.literal_eval(j)['word']
        subject_entity.append(i)
        object_entity.append(j)
    index = np.arange(len(dataset))
    out_dataset = pd.DataFrame({
        'index': index,
        'id': dataset['id'],
        'sentence': dataset['sentence'],
        'subject_entity': subject_entity,
        'object_entity': object_entity,
        'label': dataset['label'],
    })
    return out_dataset
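
# Example of the entity format this function expects in the csv (an assumed KLUE-style
# stringified dict; only the 'word' key is actually read here):
#
#     "{'word': '이순신', 'start_idx': 0, 'end_idx': 2, 'type': 'PER'}"
#
# The resulting DataFrame has the columns
# index, id, sentence, subject_entity, object_entity, label.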


def load_data(dataset_dir):
    """Load the csv file from the given path and preprocess it."""
    pd_dataset = pd.read_csv(dataset_dir)
    dataset = preprocessing_dataset(pd_dataset)
    return dataset


def tokenized_dataset(dataset, tokenizer):
    """Tokenize the sentences according to the given tokenizer."""
    concat_entity = []
    for e01, e02 in zip(dataset['subject_entity'], dataset['object_entity']):
        # Build the entity-pair string "subject[SEP]object" used as the first segment.
        temp = e01 + '[SEP]' + e02
        concat_entity.append(temp)
    if "roberta" in tokenizer.name_or_path:
        # RoBERTa models do not use token_type_ids, so they are not returned.
        tokenized_sentences = tokenizer(
            concat_entity,
            list(dataset['sentence']),
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
            add_special_tokens=True,
            return_token_type_ids=False,
        )
    else:
        tokenized_sentences = tokenizer(
            concat_entity,
            list(dataset['sentence']),
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
            add_special_tokens=True,
        )
    return tokenized_sentences


def marker_tokenized_dataset(dataset, tokenizer):
    """Tokenize entity-marked sentences according to the given tokenizer."""
    # add_entity_mark (from entity_marker) inserts entity marker tokens into each sentence;
    # its result is passed directly to the tokenizer, so it is expected to be a list of strings.
    dataset = add_entity_mark(dataset)
    print('add_entity_mark')
    if "roberta" in tokenizer.name_or_path:
        # RoBERTa models do not use token_type_ids, so they are not returned.
        tokenized_sentences = tokenizer(
            dataset,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
            add_special_tokens=True,
            return_token_type_ids=False,
        )
    else:
        tokenized_sentences = tokenizer(
            dataset,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
            add_special_tokens=True,
        )
    return tokenized_sentences
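

# A minimal end-to-end sketch of how these helpers fit together. The model name, file path,
# and label handling are illustrative assumptions, not part of this module:
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")  # assumed model name
#     train_df = load_data("../dataset/train/train.csv")               # assumed path
#     train_labels = train_df['label'].values                          # assumes numeric label ids
#
#     tokenized_train = tokenized_dataset(train_df, tokenizer)
#     train_dataset = RE_Dataset(tokenized_train, train_labels)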