diff --git a/denser_retriever/reranker.py b/denser_retriever/reranker.py
index 7d52196..2c79649 100644
--- a/denser_retriever/reranker.py
+++ b/denser_retriever/reranker.py
@@ -5,7 +5,7 @@ class Reranker:
     def __init__(self, rerank_model: str):
-        self.model = CrossEncoder(rerank_model, max_length=512)
+        self.model = CrossEncoder(rerank_model, trust_remote_code=True, max_length=512)
 
     def rerank(self, query, passages, batch_size, query_id=None):
         passages_copy = copy.deepcopy(passages)
@@ -18,7 +18,7 @@ def rerank(self, query, passages, batch_size, query_id=None):
         for i in range(0, num_passages, batch_size):
             batch = passage_texts[i : i + batch_size]
-            scores = self.model.predict(batch).tolist()
+            scores = self.model.predict(batch, batch_size=batch_size, convert_to_tensor=True).tolist()
 
             for j, passage in enumerate(passages_copy[i : i + batch_size]):
                 score_rerank = scores[j] if isinstance(scores, list) else scores
diff --git a/denser_retriever/retriever_milvus.py b/denser_retriever/retriever_milvus.py
index 08efd25..fd70def 100644
--- a/denser_retriever/retriever_milvus.py
+++ b/denser_retriever/retriever_milvus.py
@@ -147,12 +147,13 @@ def connect_index(self):
         with open(fields_file, "r") as file:
             self.field_cat_to_id, self.field_id_to_cat = json.load(file)
 
-    def generate_embedding(self, texts, query=False):
+    def generate_embedding(self, texts, batch_size=32, query=False):
         if query and not self.config.one_model:
-            embeddings = self.model.encode(texts, prompt_name="query")
+            embeddings = self.model.encode(
+                texts, batch_size=batch_size, prompt_name="query")
             # embeddings = self.model.encode(texts, prompt="Represent this sentence for searching relevant passages:")
         else:
-            embeddings = self.model.encode(texts)
+            embeddings = self.model.encode(texts, batch_size=batch_size)
         return embeddings
 
     def ingest(self, doc_or_passage_file, batch_size):
@@ -203,8 +204,8 @@ def ingest(self, doc_or_passage_file, batch_size):
                 else:  # missing category value
                     fieldss[i].append(-1)
             record_id += 1
-            if len(batch) == batch_size:
-                embeddings = self.generate_embedding(batch)
+            if len(batch) == batch_size*10:
+                embeddings = self.generate_embedding(batch, batch_size)
                 record = [uids, sources, titles, texts, pids, np.array(embeddings)]
                 record += fieldss
                 try:
@@ -215,7 +216,7 @@ def ingest(self, doc_or_passage_file, batch_size):
                     )
                 records_per_file.append(record)
-                if len(records_per_file) == 1000:
+                if len(records_per_file) == 10:
                     with open(f"{cache_file}_{record_id}.pkl", "wb") as file:
                         pickle.dump(records_per_file, file)
                     records_per_file = []
@@ -229,7 +230,7 @@ def ingest(self, doc_or_passage_file, batch_size):
                 fieldss = [[] for _ in self.field_types.keys()]
 
         if len(batch) > 0:
-            embeddings = self.generate_embedding(batch)
+            embeddings = self.generate_embedding(batch, batch_size)
             record = [uids, sources, titles, texts, pids, np.array(embeddings)]
             record += fieldss
             try:
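
For context on the two sentence-transformers call sites this patch changes, a minimal standalone sketch follows. The model names below are illustrative placeholders, not the repo's configured defaults; trust_remote_code=True only matters for checkpoints that ship custom modeling code and should be enabled only for model repos you trust.

# Sketch of the changed calls, assuming sentence-transformers is installed.
from sentence_transformers import CrossEncoder, SentenceTransformer

# Reranker side: predict() accepts batch_size directly, and
# convert_to_tensor=True returns a torch tensor whose .tolist() still
# yields plain Python floats, so downstream scoring code is unchanged.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2",
                        trust_remote_code=True, max_length=512)
pairs = [("what is milvus?", "Milvus is an open-source vector database."),
         ("what is milvus?", "Bananas are rich in potassium.")]
scores = reranker.predict(pairs, batch_size=32, convert_to_tensor=True).tolist()

# Embedding side: encode() also takes batch_size, so the ingest loop can
# accumulate a larger buffer (batch_size*10 texts in the hunk above) while
# encode() still embeds it in batch_size-sized chunks internally.
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = embedder.encode(["a passage to embed"], batch_size=32)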