From 882dff96a4a377acfa944eea2b344b680f65785d Mon Sep 17 00:00:00 2001 From: zhiheng huang Date: Sun, 7 Jul 2024 10:09:00 -0700 Subject: [PATCH] Revert "feature Improvement about sentence_transformers(embedding,Reranker)" --- denser_retriever/reranker.py | 4 ++-- denser_retriever/retriever_milvus.py | 15 +++++++-------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/denser_retriever/reranker.py b/denser_retriever/reranker.py index 2c79649..7d52196 100644 --- a/denser_retriever/reranker.py +++ b/denser_retriever/reranker.py @@ -5,7 +5,7 @@ class Reranker: def __init__(self, rerank_model: str): - self.model = CrossEncoder(rerank_model, trust_remote_code=True, max_length=512) + self.model = CrossEncoder(rerank_model, max_length=512) def rerank(self, query, passages, batch_size, query_id=None): passages_copy = copy.deepcopy(passages) @@ -18,7 +18,7 @@ def rerank(self, query, passages, batch_size, query_id=None): for i in range(0, num_passages, batch_size): batch = passage_texts[i : i + batch_size] - scores = self.model.predict(batch, batch_size=batch_size, convert_to_tensor=True).tolist() + scores = self.model.predict(batch).tolist() for j, passage in enumerate(passages_copy[i : i + batch_size]): score_rerank = scores[j] if isinstance(scores, list) else scores diff --git a/denser_retriever/retriever_milvus.py b/denser_retriever/retriever_milvus.py index fd70def..08efd25 100644 --- a/denser_retriever/retriever_milvus.py +++ b/denser_retriever/retriever_milvus.py @@ -147,13 +147,12 @@ def connect_index(self): with open(fields_file, "r") as file: self.field_cat_to_id, self.field_id_to_cat = json.load(file) - def generate_embedding(self, texts, batch_size=32, query=False): + def generate_embedding(self, texts, query=False): if query and not self.config.one_model: - embeddings = self.model.encode( - texts, batch_size=batch_size, prompt_name="query") + embeddings = self.model.encode(texts, prompt_name="query") # embeddings = self.model.encode(texts, prompt="Represent this sentence for searching relevant passages:") else: - embeddings = self.model.encode(texts, batch_size=batch_size) + embeddings = self.model.encode(texts) return embeddings def ingest(self, doc_or_passage_file, batch_size): @@ -204,8 +203,8 @@ def ingest(self, doc_or_passage_file, batch_size): else: # missing category value fieldss[i].append(-1) record_id += 1 - if len(batch) == batch_size*10: - embeddings = self.generate_embedding(batch, batch_size) + if len(batch) == batch_size: + embeddings = self.generate_embedding(batch) record = [uids, sources, titles, texts, pids, np.array(embeddings)] record += fieldss try: @@ -216,7 +215,7 @@ def ingest(self, doc_or_passage_file, batch_size): ) records_per_file.append(record) - if len(records_per_file) == 10: + if len(records_per_file) == 1000: with open(f"{cache_file}_{record_id}.pkl", "wb") as file: pickle.dump(records_per_file, file) records_per_file = [] @@ -230,7 +229,7 @@ def ingest(self, doc_or_passage_file, batch_size): fieldss = [[] for _ in self.field_types.keys()] if len(batch) > 0: - embeddings = self.generate_embedding(batch, batch_size) + embeddings = self.generate_embedding(batch) record = [uids, sources, titles, texts, pids, np.array(embeddings)] record += fieldss try: