-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdedup_prompts_metdata_format.py
133 lines (105 loc) · 4.69 KB
/
dedup_prompts_metdata_format.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import json
import torch
from sentence_transformers import SentenceTransformer
def _read_jsonl(path):
    """Read a metadata JSONL file; return (items, texts) where texts = item["text"]."""
    items, texts = [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            items.append(item)
            texts.append(item["text"])
    return items, texts


def _similar_pairs(embeds, thres, chunk_size):
    """Yield global index pairs (x, y), x != y, with cosine similarity > thres.

    `embeds` must be L2-normalized so the dot product equals cosine similarity.
    Rows are processed in chunks of `chunk_size` to bound the size of the
    (chunk x N) similarity matrix held in memory at once.
    """
    for start in range(0, embeds.size(0), chunk_size):
        chunk = embeds[start:start + chunk_size]
        similarities = torch.mm(chunk, embeds.T)
        rows, cols = torch.nonzero(similarities > thres, as_tuple=True)
        for row, col in zip(rows.tolist(), cols.tolist()):
            x = row + start  # translate chunk-local row back to a global index
            if x != col:
                yield x, col


def _write_jsonl(path, items):
    """Write `items` to `path` as one JSON object per line (UTF-8, non-ASCII kept)."""
    with open(path, "w", encoding="utf-8") as f:
        for item in items:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")


if __name__ == "__main__":
    cross_split_dedup = True
    thres = 0.85  # cosine-similarity threshold above which two prompts are duplicates
    chunk_size = 1000

    st_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2").cuda()

    if not cross_split_dedup:
        # Deduplicate a single split against itself.
        input_path = "./diffusion/image-to-prompt-train-valid-split-v4/validation/metadata_concat.jsonl"
        output_path = "./diffusion/image-to-prompt-train-valid-split-v4/validation/metadata_concat_dedup.jsonl"

        items, prompts = _read_jsonl(input_path)
        # normalize_embeddings=True so the chunked matmul is a true cosine
        # similarity (raw MiniLM embeddings are not unit-length by default).
        embeds = torch.tensor(
            st_model.encode(prompts, normalize_embeddings=True)
        ).cuda()
        print("embeds created")

        deleted_indices = set()
        for x, y in _similar_pairs(embeds, thres, chunk_size):
            # Similarity is symmetric, so each duplicate pair appears as both
            # (a, b) and (b, a); delete only the later index so the first
            # occurrence of each duplicate group is kept.
            if y > x:
                deleted_indices.add(y)
        print("deleted ids:", len(deleted_indices))

        results = [
            item for idx, item in enumerate(items) if idx not in deleted_indices
        ]
        _write_jsonl(output_path, results)
    else:
        # Remove from the validation split any prompt that duplicates a train
        # prompt. Works for any pair of files with the same metadata format.
        train_input_path = "./diffusion/image-to-prompt-train-valid-split-v4/train/metadata_concat.jsonl"
        valid_input_path = "./diffusion/image-to-prompt-train-valid-split-v4/validation/metadata_concat.jsonl"
        file_name = valid_input_path.split("/")[-1]
        output_path = valid_input_path.replace(
            file_name, file_name.split(".")[0] + "_split_dedup.jsonl"
        )

        # Concatenate both splits, embed once, and separate them by index.
        train_items, train_prompts = _read_jsonl(train_input_path)
        valid_items, valid_prompts = _read_jsonl(valid_input_path)
        #: last index belonging to the training data
        end_of_train_index = len(train_prompts) - 1
        items = train_items + valid_items
        prompts = train_prompts + valid_prompts

        #: embeddings for train + validation combined (unit-normalized)
        embeds = torch.tensor(
            st_model.encode(prompts, normalize_embeddings=True)
        ).cuda()
        print("embeds created")

        deleted_indices = set()
        for x, y in _similar_pairs(embeds, thres, chunk_size):
            # Anchor must be a train index and the deletion target a
            # validation index: x <= end_of_train_index < y.
            if x <= end_of_train_index < y:
                deleted_indices.add(y)
        print("deleted ids:", len(deleted_indices))

        # Emit only validation items that were not flagged as duplicates.
        results = [
            item
            for idx, item in enumerate(items)
            if idx > end_of_train_index and idx not in deleted_indices
        ]
        _write_jsonl(output_path, results)