Skip to content

Commit

Permalink
Merge pull request #15 from denser-org/revert-14-main
Browse files Browse the repository at this point in the history
Revert "feature Improvement about sentence_transformers(embedding,Reranker)"
  • Loading branch information
zhiheng-huang authored Jul 7, 2024
2 parents a7e048e + 882dff9 commit b7ea2e2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
4 changes: 2 additions & 2 deletions denser_retriever/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

class Reranker:
def __init__(self, rerank_model: str):
    """Load a cross-encoder reranking model.

    Args:
        rerank_model: Model name or path accepted by
            ``sentence_transformers.CrossEncoder`` (e.g. a Hugging Face id).

    The diff residue showed two consecutive assignments (pre- and
    post-revert); the post-commit state keeps only the call WITHOUT
    ``trust_remote_code=True`` — the revert deliberately stopped executing
    remote model code. ``max_length=512`` truncates query/passage pairs to
    the encoder's input limit.
    """
    self.model = CrossEncoder(rerank_model, max_length=512)

def rerank(self, query, passages, batch_size, query_id=None):
passages_copy = copy.deepcopy(passages)
Expand All @@ -18,7 +18,7 @@ def rerank(self, query, passages, batch_size, query_id=None):

for i in range(0, num_passages, batch_size):
batch = passage_texts[i : i + batch_size]
scores = self.model.predict(batch, batch_size=batch_size, convert_to_tensor=True).tolist()
scores = self.model.predict(batch).tolist()

for j, passage in enumerate(passages_copy[i : i + batch_size]):
score_rerank = scores[j] if isinstance(scores, list) else scores
Expand Down
15 changes: 7 additions & 8 deletions denser_retriever/retriever_milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,12 @@ def connect_index(self):
with open(fields_file, "r") as file:
self.field_cat_to_id, self.field_id_to_cat = json.load(file)

def generate_embedding(self, texts, query=False):
    """Encode texts into dense vectors with the retriever's embedding model.

    Args:
        texts: A string or list of strings to embed (passed straight to
            ``self.model.encode``).
        query: When True AND the config uses separate query/passage models
            (``not self.config.one_model``), encode with the model's
            ``prompt_name="query"`` prompt so queries and passages land in
            the same space; otherwise encode with the default prompt.

    Returns:
        The embeddings exactly as returned by ``self.model.encode``
        (typically a numpy array — TODO confirm against the model config).

    The diff residue contained both the pre-revert signature (with a
    ``batch_size`` parameter) and the post-revert one; this is the
    post-commit state, which lets ``encode`` use its own default batching.
    """
    if query and not self.config.one_model:
        embeddings = self.model.encode(texts, prompt_name="query")
        # NOTE(review): an alternative instruction-style prompt
        # ("Represent this sentence for searching relevant passages:")
        # was left commented out in the original.
    else:
        embeddings = self.model.encode(texts)
    return embeddings

def ingest(self, doc_or_passage_file, batch_size):
Expand Down Expand Up @@ -204,8 +203,8 @@ def ingest(self, doc_or_passage_file, batch_size):
else: # missing category value
fieldss[i].append(-1)
record_id += 1
if len(batch) == batch_size*10:
embeddings = self.generate_embedding(batch, batch_size)
if len(batch) == batch_size:
embeddings = self.generate_embedding(batch)
record = [uids, sources, titles, texts, pids, np.array(embeddings)]
record += fieldss
try:
Expand All @@ -216,7 +215,7 @@ def ingest(self, doc_or_passage_file, batch_size):
)

records_per_file.append(record)
if len(records_per_file) == 10:
if len(records_per_file) == 1000:
with open(f"{cache_file}_{record_id}.pkl", "wb") as file:
pickle.dump(records_per_file, file)
records_per_file = []
Expand All @@ -230,7 +229,7 @@ def ingest(self, doc_or_passage_file, batch_size):
fieldss = [[] for _ in self.field_types.keys()]

if len(batch) > 0:
embeddings = self.generate_embedding(batch, batch_size)
embeddings = self.generate_embedding(batch)
record = [uids, sources, titles, texts, pids, np.array(embeddings)]
record += fieldss
try:
Expand Down

0 comments on commit b7ea2e2

Please sign in to comment.