From aae4bebe33726f45e111075cbcf02fb88cacdd64 Mon Sep 17 00:00:00 2001
From: EdgeNeko
Date: Thu, 28 Dec 2023 21:42:21 +0800
Subject: [PATCH] Refactor transformers_service to reduce exposing API

---
 app/Services/transformers_service.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/app/Services/transformers_service.py b/app/Services/transformers_service.py
index edd104a..e745832 100644
--- a/app/Services/transformers_service.py
+++ b/app/Services/transformers_service.py
@@ -18,12 +18,12 @@ def __init__(self):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         logger.info("Using device: {}; CLIP Model: {}, BERT Model: {}",
                     self.device, config.clip.model, config.ocr_search.bert_model)
-        self.clip_model = CLIPModel.from_pretrained(config.clip.model).to(self.device)
-        self.clip_processor = CLIPProcessor.from_pretrained(config.clip.model)
+        self._clip_model = CLIPModel.from_pretrained(config.clip.model).to(self.device)
+        self._clip_processor = CLIPProcessor.from_pretrained(config.clip.model)
         logger.success("CLIP Model loaded successfully")
         if config.ocr_search.enable:
-            self.bert_model = BertModel.from_pretrained(config.ocr_search.bert_model).to(self.device)
-            self.bert_tokenizer = BertTokenizer.from_pretrained(config.ocr_search.bert_model)
+            self._bert_model = BertModel.from_pretrained(config.ocr_search.bert_model).to(self.device)
+            self._bert_tokenizer = BertTokenizer.from_pretrained(config.ocr_search.bert_model)
             logger.success("BERT Model loaded successfully")
         else:
             logger.info("OCR search is disabled. Skipping OCR and BERT model loading.")
@@ -34,11 +34,10 @@ def get_image_vector(self, image: Image.Image) -> ndarray:
             image = image.convert("RGB")
         logger.info("Processing image...")
         start_time = time()
-        inputs = self.clip_processor(images=image, return_tensors="pt").to(self.device)
+        inputs = self._clip_processor(images=image, return_tensors="pt").to(self.device)
         logger.success("Image processed, now inferencing with CLIP model...")
-        outputs: FloatTensor = self.clip_model.get_image_features(**inputs)
+        outputs: FloatTensor = self._clip_model.get_image_features(**inputs)
         logger.success("Inference done. Time elapsed: {:.2f}s", time() - start_time)
-        logger.info("Norm: {}", outputs.norm(dim=-1).item())
         outputs /= outputs.norm(dim=-1, keepdim=True)
         return outputs.numpy(force=True).reshape(-1)
 
@@ -46,11 +45,10 @@ def get_image_vector(self, image: Image.Image) -> ndarray:
     def get_text_vector(self, text: str) -> ndarray:
         logger.info("Processing text...")
         start_time = time()
-        inputs = self.clip_processor(text=text, return_tensors="pt").to(self.device)
+        inputs = self._clip_processor(text=text, return_tensors="pt").to(self.device)
         logger.success("Text processed, now inferencing with CLIP model...")
-        outputs: FloatTensor = self.clip_model.get_text_features(**inputs)
+        outputs: FloatTensor = self._clip_model.get_text_features(**inputs)
         logger.success("Inference done. Time elapsed: {:.2f}s", time() - start_time)
-        logger.info("Norm: {}", outputs.norm(dim=-1).item())
         outputs /= outputs.norm(dim=-1, keepdim=True)
         return outputs.numpy(force=True).reshape(-1)
 
@@ -58,8 +56,8 @@ def get_text_vector(self, text: str) -> ndarray:
     def get_bert_vector(self, text: str) -> ndarray:
         start_time = time()
         logger.info("Inferencing with BERT model...")
-        inputs = self.bert_tokenizer(text.strip().lower(), return_tensors="pt").to(self.device)
-        outputs = self.bert_model(**inputs)
+        inputs = self._bert_tokenizer(text.strip().lower(), return_tensors="pt").to(self.device)
+        outputs = self._bert_model(**inputs)
         vector = outputs.last_hidden_state.mean(dim=1).squeeze()
         logger.success("BERT inference done. Time elapsed: {:.2f}s", time() - start_time)
         return vector.cpu().numpy()