Skip to content

Commit

Permalink
Refactor transformers_service to reduce the exposed API surface
Browse files Browse the repository at this point in the history
  • Loading branch information
hv0905 committed Dec 28, 2023
1 parent 6a8f28a commit aae4beb
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions app/Services/transformers_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ def __init__(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info("Using device: {}; CLIP Model: {}, BERT Model: {}",
self.device, config.clip.model, config.ocr_search.bert_model)
self.clip_model = CLIPModel.from_pretrained(config.clip.model).to(self.device)
self.clip_processor = CLIPProcessor.from_pretrained(config.clip.model)
self._clip_model = CLIPModel.from_pretrained(config.clip.model).to(self.device)
self._clip_processor = CLIPProcessor.from_pretrained(config.clip.model)
logger.success("CLIP Model loaded successfully")
if config.ocr_search.enable:
self.bert_model = BertModel.from_pretrained(config.ocr_search.bert_model).to(self.device)
self.bert_tokenizer = BertTokenizer.from_pretrained(config.ocr_search.bert_model)
self._bert_model = BertModel.from_pretrained(config.ocr_search.bert_model).to(self.device)
self._bert_tokenizer = BertTokenizer.from_pretrained(config.ocr_search.bert_model)
logger.success("BERT Model loaded successfully")
else:
logger.info("OCR search is disabled. Skipping OCR and BERT model loading.")
@no_grad()  # NOTE(review): sibling methods carry @no_grad(); this method's own decorator line is outside the visible diff hunk — confirm against the full file
def get_image_vector(self, image: Image.Image) -> ndarray:
    """Embed *image* with the CLIP model and return a flat unit vector.

    The image is converted to RGB, preprocessed with the CLIP processor,
    run through the CLIP vision tower, L2-normalised, and returned as a
    1-D numpy array.

    :param image: input PIL image (any mode; converted to RGB here).
    :return: flattened, L2-normalised image embedding.
    """
    image = image.convert("RGB")
    logger.info("Processing image...")
    start_time = time()
    # Post-refactor private attribute names (_clip_*); the stale
    # pre-refactor self.clip_* duplicates from the diff are dropped.
    inputs = self._clip_processor(images=image, return_tensors="pt").to(self.device)
    logger.success("Image processed, now inferencing with CLIP model...")
    outputs: FloatTensor = self._clip_model.get_image_features(**inputs)
    logger.success("Inference done. Time elapsed: {:.2f}s", time() - start_time)
    logger.info("Norm: {}", outputs.norm(dim=-1).item())
    # Unit-normalise the embedding in place before returning it.
    outputs /= outputs.norm(dim=-1, keepdim=True)
    return outputs.numpy(force=True).reshape(-1)

@no_grad()
def get_text_vector(self, text: str) -> ndarray:
    """Embed *text* with the CLIP text tower and return a flat unit vector.

    The text is tokenised with the CLIP processor, run through the CLIP
    text encoder, L2-normalised, and returned as a 1-D numpy array.

    :param text: query text to embed.
    :return: flattened, L2-normalised text embedding.
    """
    logger.info("Processing text...")
    start_time = time()
    # Post-refactor private attribute names (_clip_*); the stale
    # pre-refactor self.clip_* duplicates from the diff are dropped.
    inputs = self._clip_processor(text=text, return_tensors="pt").to(self.device)
    logger.success("Text processed, now inferencing with CLIP model...")
    outputs: FloatTensor = self._clip_model.get_text_features(**inputs)
    logger.success("Inference done. Time elapsed: {:.2f}s", time() - start_time)
    logger.info("Norm: {}", outputs.norm(dim=-1).item())
    # Unit-normalise the embedding in place before returning it.
    outputs /= outputs.norm(dim=-1, keepdim=True)
    return outputs.numpy(force=True).reshape(-1)

@no_grad()
def get_bert_vector(self, text: str) -> ndarray:
    """Embed *text* with the BERT model.

    The input is stripped and lower-cased, tokenised, run through BERT,
    and the last hidden states are mean-pooled over the sequence axis.

    :param text: text to embed (normalised to lower case here).
    :return: mean-pooled BERT embedding as a numpy array on the CPU.
    """
    start_time = time()
    logger.info("Inferencing with BERT model...")
    # Post-refactor private attribute names (_bert_*); the stale
    # pre-refactor self.bert_* duplicates from the diff are dropped.
    inputs = self._bert_tokenizer(text.strip().lower(), return_tensors="pt").to(self.device)
    outputs = self._bert_model(**inputs)
    # Mean-pool token states along dim=1 (sequence), then drop the batch dim.
    vector = outputs.last_hidden_state.mean(dim=1).squeeze()
    logger.success("BERT inference done. Time elapsed: {:.2f}s", time() - start_time)
    return vector.cpu().numpy()
Expand Down

0 comments on commit aae4beb

Please sign in to comment.