diff --git a/.gitignore b/.gitignore index 8a38644..d704caa 100644 --- a/.gitignore +++ b/.gitignore @@ -134,3 +134,4 @@ dmypy.json data/* embeddings/* results/* +models/* diff --git a/backend/my_executors.py b/backend/my_executors.py index ee15566..4bc0ca9 100644 --- a/backend/my_executors.py +++ b/backend/my_executors.py @@ -7,7 +7,7 @@ from transformers import BertModel, BertTokenizer from jina import Executor, requests, Document, DocumentArray -from backend_config import top_k, embeddings_path +from backend_config import top_k, embeddings_path from utils import partition from helpers import log @@ -19,12 +19,17 @@ def __init__(self, **kwargs): log("Initialising ProtBertExecutor.") super().__init__() + model_path = "../models/prot_bert" + if not os.path.exists(model_path): + log(f"Downloading model {model_path}.") + model_path = "Rostlab/prot_bert" + else: + log(f"Using local model: {model_path}") + log("Setting tokenizer.") - tokenizer = BertTokenizer.from_pretrained( - "Rostlab/prot_bert", do_lower_case=False - ) + tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=False) log("Setting model.") - model = BertModel.from_pretrained("Rostlab/prot_bert") + model = BertModel.from_pretrained(model_path) self.tokenizer = tokenizer self.model = model