Skip to content

Commit

Permalink
Merge pull request #14 from NLPJCL/main
Browse files Browse the repository at this point in the history
Feature: improvements to sentence_transformers usage (embedding, Reranker)
  • Loading branch information
zhiheng-huang authored Jul 7, 2024
2 parents 8f4b21d + 022e859 commit a7e048e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
4 changes: 2 additions & 2 deletions denser_retriever/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

class Reranker:
def __init__(self, rerank_model: str):
    """Load the cross-encoder model used for reranking.

    Args:
        rerank_model: Model name or local path forwarded to
            ``sentence_transformers.CrossEncoder``.
    """
    # trust_remote_code=True is required for models that ship custom
    # modeling code on the Hub; max_length caps inputs at 512 tokens.
    self.model = CrossEncoder(rerank_model, trust_remote_code=True, max_length=512)

def rerank(self, query, passages, batch_size, query_id=None):
passages_copy = copy.deepcopy(passages)
Expand All @@ -18,7 +18,7 @@ def rerank(self, query, passages, batch_size, query_id=None):

for i in range(0, num_passages, batch_size):
batch = passage_texts[i : i + batch_size]
scores = self.model.predict(batch).tolist()
scores = self.model.predict(batch, batch_size=batch_size, convert_to_tensor=True).tolist()

for j, passage in enumerate(passages_copy[i : i + batch_size]):
score_rerank = scores[j] if isinstance(scores, list) else scores
Expand Down
15 changes: 8 additions & 7 deletions denser_retriever/retriever_milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,13 @@ def connect_index(self):
with open(fields_file, "r") as file:
self.field_cat_to_id, self.field_id_to_cat = json.load(file)

def generate_embedding(self, texts, batch_size=32, query=False):
    """Encode text(s) into dense embeddings with the configured model.

    Args:
        texts: A string or list of strings to embed.
        batch_size: Batch size forwarded to ``model.encode``.
        query: If True and a query-specific encoding is configured
            (``not self.config.one_model``), encode with the model's
            "query" prompt; otherwise encode as passages.

    Returns:
        Embeddings as returned by ``self.model.encode``.
    """
    if query and not self.config.one_model:
        # Asymmetric retrieval models use a dedicated query prompt
        # (prompt_name must be defined in the model's configuration).
        embeddings = self.model.encode(
            texts, batch_size=batch_size, prompt_name="query")
    else:
        embeddings = self.model.encode(texts, batch_size=batch_size)
    return embeddings

def ingest(self, doc_or_passage_file, batch_size):
Expand Down Expand Up @@ -203,8 +204,8 @@ def ingest(self, doc_or_passage_file, batch_size):
else: # missing category value
fieldss[i].append(-1)
record_id += 1
if len(batch) == batch_size:
embeddings = self.generate_embedding(batch)
if len(batch) == batch_size*10:
embeddings = self.generate_embedding(batch, batch_size)
record = [uids, sources, titles, texts, pids, np.array(embeddings)]
record += fieldss
try:
Expand All @@ -215,7 +216,7 @@ def ingest(self, doc_or_passage_file, batch_size):
)

records_per_file.append(record)
if len(records_per_file) == 1000:
if len(records_per_file) == 10:
with open(f"{cache_file}_{record_id}.pkl", "wb") as file:
pickle.dump(records_per_file, file)
records_per_file = []
Expand All @@ -229,7 +230,7 @@ def ingest(self, doc_or_passage_file, batch_size):
fieldss = [[] for _ in self.field_types.keys()]

if len(batch) > 0:
embeddings = self.generate_embedding(batch)
embeddings = self.generate_embedding(batch, batch_size)
record = [uids, sources, titles, texts, pids, np.array(embeddings)]
record += fieldss
try:
Expand Down

0 comments on commit a7e048e

Please sign in to comment.