# CLIP_Processor.py
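# CLIP-based image retrieval: embed a catalogue of product images with CLIP,
# store the L2-normalized embeddings in a FAISS inner-product index, and answer
# an image query by cosine similarity against that index.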
from lib import *  # `lib` is assumed to re-export the shared dependencies used below (os, json, faiss, tqdm, PIL.Image, psutil)
from model import get_model
from helper_function import Helper
helper = Helper()
class CLIP_Processor:
    def __init__(self, image_query):
        self.dataset_dict, self.image_filenames = helper.dataset_dict, helper.image_filenames
        # Inner-product index over 512-d CLIP embeddings; with L2-normalized
        # vectors, inner product equals cosine similarity.
        self.index = faiss.IndexFlatIP(512)
        self.image_query = image_query
        self.model, self.processor = get_model("CLIP")
def get_image_embedding(self, image):
inputs = self.processor(images=[image], return_tensors="pt", padding=True)
outputs = self.model.get_image_features(**inputs)
return outputs.cpu().detach().numpy()
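    # Text-side twin of get_image_embedding; unused in this file, but it enables
    # text-to-image search (see the sketch after the class).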
def get_text_embedding(self, text):
inputs = self.processor(text=[text], return_tensors="pt", padding=True)
outputs = self.model.get_text_features(**inputs)
return outputs.cpu().detach().numpy()
    def embedding_image_database(self):  # build the FAISS index of image embeddings, or load it from disk
        SAVE_INTERVAL = 100
        os.makedirs("vector_db", exist_ok=True)
        path_index_json = "vector_db/index_CLIP.json"
        path_index_bin = "vector_db/index_CLIP.bin"
        if not os.path.exists(path_index_bin):
            for i, file in tqdm.tqdm(enumerate(self.image_filenames), total=len(self.image_filenames)):
                image = Image.open(file).convert("RGB")
                embedding = self.get_image_embedding(image)
                faiss.normalize_L2(embedding)
                self.index.add(embedding)
                if i % SAVE_INTERVAL == 0:  # periodic checkpoint so a crash does not lose the whole build
                    faiss.write_index(self.index, path_index_bin)
            faiss.write_index(self.index, path_index_bin)
        else:
            self.index = faiss.read_index(path_index_bin)
        if not os.path.exists(path_index_json):
            # Persist the filename list so search indices can be mapped back to paths.
            with open(path_index_json, "w") as f:
                json.dump(self.image_filenames, f)
    def Query(self, image_query, top_k: int):  # retrieve the top_k images most similar to image_query
        query_embedding = self.get_image_embedding(image_query)
        faiss.normalize_L2(query_embedding)
        # IndexFlatIP returns cosine similarities (already ranked), not Euclidean distances.
        similarities, indices = self.index.search(query_embedding, top_k)
        # print("Similarities:", similarities)
        # print("Indices:", indices)
        path_images = [self.image_filenames[i] for i in indices[0]]
        image_names = self.dataset_dict[path_images[0]]
        avg_similarity = sum(similarities[0]) / len(similarities[0])
        if avg_similarity < 0.7:  # low average similarity: no confident match in the catalogue
            return path_images, f"We do not have such a product || {image_names}", avg_similarity
        return path_images, image_names, avg_similarity
    def run(self):
        # Build (or load) the image embedding index
        self.embedding_image_database()
        # Measure RAM usage before inference
        ram_before_infer = psutil.virtual_memory().used / (1024 ** 2)  # MB
        # Query with the input image
        path_images, image_names, avg_similarity = self.Query(self.image_query, top_k=5)
        # Measure RAM usage after inference
        ram_after_infer = psutil.virtual_memory().used / (1024 ** 2)  # MB
        # Plot the retrieved images
        helper.plot_results(path_images)
        # print(f"RAM Used by Model CLIP to Inference: {ram_after_infer - ram_before_infer:.2f} MB")
        return path_images, image_names, avg_similarity
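
# The class defines get_text_embedding but never calls it. A minimal sketch of
# how a text query could reuse the same index (query_by_text is a hypothetical
# helper, not part of the original file): embed the caption, L2-normalize, and
# search exactly as the image path does.
def query_by_text(proc: CLIP_Processor, text: str, top_k: int = 5):
    proc.embedding_image_database()                 # make sure the index is built or loaded
    text_embedding = proc.get_text_embedding(text)  # (1, 512) CLIP text feature
    faiss.normalize_L2(text_embedding)              # cosine similarity via inner product
    similarities, indices = proc.index.search(text_embedding, top_k)
    return [proc.image_filenames[i] for i in indices[0]], similarities[0]
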
if __name__ == "__main__":
    dataset_dir = "test_query"
    # All candidate query images (used by the batch sketch below).
    test_query = [os.path.join(dataset_dir, path_query) for path_query in os.listdir(dataset_dir)]
    image_query = Image.open("test_query/2ed921949deaddcb969d301a3f40f993_png_jpg.rf.60e03df4d14d3b19194e7c49b33ed106.jpg")
    # Show the query image
    # plt.imshow(image_query)
    # plt.axis('off')
    # plt.show()
    processor = CLIP_Processor(image_query=image_query)
    path_images, image_names, avg_similarity = processor.run()
    print(avg_similarity)
    print(image_names)
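
    # Hypothetical batch run over every file in test_query (built above but
    # otherwise unused): reuse the processor and call Query directly. A sketch
    # only; it assumes every file in the folder opens as an image.
    for path in sorted(test_query):
        result_paths, result_names, result_score = processor.Query(Image.open(path).convert("RGB"), top_k=5)
        print(path, "->", result_names, f"(avg similarity {result_score:.3f})")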