Parallel Processing of Colbert Reranker (#378)
* make colbert embedding run in batches

* implement parallel processing of colbert reranker

* add a one-batch test

---------

Co-authored-by: jeffrey <[email protected]>
Co-authored-by: Bwook (Byoungwook) Kim <[email protected]>
3 people authored Apr 27, 2024

1 parent 745b845 commit 96ef3c7
Showing 2 changed files with 98 additions and 32 deletions.
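
The heart of the change, visible in the colbert.py diff below: per-query async scoring is replaced by batched embedding plus a pandas explode/groupby pass that scores every (query, document) pair row-wise and folds the scores back into per-query lists. A toy sketch of that pattern, with hypothetical data and len() standing in for the real get_colbert_score:

```python
import pandas as pd

# One row per query; 'docs' holds that query's candidate documents.
df = pd.DataFrame({'query': ['q1', 'q2'], 'docs': [['a', 'bb'], ['ccc']]})

# explode() fans out to one row per (query, document) pair, keeping the original index.
temp_df = df.explode('docs')
temp_df['score'] = temp_df['docs'].apply(len)  # hypothetical stand-in for get_colbert_score

# Grouping on the preserved index folds scores back into per-query lists, order intact.
df['scores'] = temp_df.groupby(level=0, sort=False)['score'].apply(list).tolist()
print(df['scores'].tolist())  # [[1, 2], [3]]
```
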
93 changes: 61 additions & 32 deletions autorag/nodes/passagereranker/colbert.py
@@ -1,11 +1,12 @@
-import asyncio
 from typing import List, Tuple
 
+import numpy as np
 import pandas as pd
 import torch
 from transformers import AutoModel, AutoTokenizer
 
 from autorag.nodes.passagereranker.base import passage_reranker_node
-from autorag.utils.util import process_batch
+from autorag.utils.util import flatten_apply, sort_by_scores, select_top_k
 
+
 @passage_reranker_node
@@ -35,43 +36,71 @@ def colbert_reranker(queries: List[str], contents_list: List[List[str]],
     model = AutoModel.from_pretrained(model_name).to(device)
     tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-    # Run async cohere_rerank_pure function
-    tasks = [get_colbert_score(query, document, model, tokenizer) for query, document, ids in
-             zip(queries, contents_list, ids_list)]
-    loop = asyncio.get_event_loop()
-    score_results = loop.run_until_complete(process_batch(tasks, batch_size=batch))
+    # get query and content embeddings
+    query_embedding_list = get_colbert_embedding_batch(queries, model, tokenizer, batch)
+    content_embedding_list = flatten_apply(get_colbert_embedding_batch, contents_list, model=model, tokenizer=tokenizer,
+                                           batch_size=batch)
 
     del model
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
 
-    def rerank_results(contents, ids, scores, top_k):
-        reranked_content, reranked_id, reranked_score = zip(
-            *sorted(zip(contents, ids, scores), key=lambda x: x[2], reverse=True))
-        return list(reranked_content)[:top_k], list(reranked_id)[:top_k], list(reranked_score)[:top_k]
+    df = pd.DataFrame({
+        'ids': ids_list,
+        'query_embedding': query_embedding_list,
+        'contents': contents_list,
+        'content_embedding': content_embedding_list,
+    })
+    temp_df = df.explode('content_embedding')
+    temp_df['score'] = temp_df.apply(lambda x: get_colbert_score(x['query_embedding'], x['content_embedding']), axis=1)
+    df['scores'] = temp_df.groupby(level=0, sort=False)['score'].apply(list).tolist()
+    df[['contents', 'ids', 'scores']] = df.apply(sort_by_scores, axis=1, result_type='expand')
+    results = select_top_k(df, ['contents', 'ids', 'scores'], top_k)
 
-    reranked_contents_list, reranked_ids_list, reranked_scores_list = zip(*list(map(
-        rerank_results, contents_list, ids_list, score_results, [top_k] * len(contents_list))))
-    return list(reranked_contents_list), list(reranked_ids_list), list(reranked_scores_list)
+    return results['contents'].tolist(), results['ids'].tolist(), results['scores'].tolist()
 
 
-async def get_colbert_score(query: str, content_list: List[str],
-                            model, tokenizer) -> List[float]:
-    query_encoding = tokenizer(query, return_tensors="pt")
-    query_embedding = model(**query_encoding).last_hidden_state
-    rerank_score_list = []
+def get_colbert_embedding_batch(input_strings: List[str],
+                                model, tokenizer, batch_size: int) -> List[np.array]:
+    encoding = tokenizer(input_strings, return_tensors="pt", padding=True)
+    input_batches = slice_tokenizer_result(encoding, batch_size)
+    result_embedding = []
+    for encoding in input_batches:
+        result_embedding.append(model(**encoding).last_hidden_state)
+    total_tensor = torch.cat(result_embedding, dim=0)  # shape [batch_size, token_length, embedding_dim]
+    tensor_results = list(total_tensor.chunk(total_tensor.size()[0]))
+    return list(map(lambda x: x.detach().numpy(), tensor_results))
 
-    for document_text in content_list:
-        document_encoding = tokenizer(
-            document_text, return_tensors="pt", truncation=True, max_length=512
-        )
-        document_embedding = model(**document_encoding).last_hidden_state
 
-        sim_matrix = torch.nn.functional.cosine_similarity(
-            query_embedding.unsqueeze(2), document_embedding.unsqueeze(1), dim=-1
-        )
+def slice_tokenizer_result(tokenizer_output, batch_size):
+    input_ids_batches = slice_tensor(tokenizer_output["input_ids"], batch_size)
+    attention_mask_batches = slice_tensor(tokenizer_output["attention_mask"], batch_size)
+    token_type_ids_batches = slice_tensor(tokenizer_output.get("token_type_ids", None), batch_size)
+    return [{"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids}
+            for input_ids, attention_mask, token_type_ids in
+            zip(input_ids_batches, attention_mask_batches, token_type_ids_batches)]
 
-        # Take the maximum similarity for each query token (across all document tokens)
-        # sim_matrix shape: [batch_size, query_length, doc_length]
-        max_sim_scores, _ = torch.max(sim_matrix, dim=2)
-        rerank_score_list.append(torch.mean(max_sim_scores, dim=1))
-
-    return list(map(float, rerank_score_list))
+def slice_tensor(input_tensor, batch_size):
+    # Calculate the number of full batches
+    num_full_batches = input_tensor.size(0) // batch_size
+
+    # Slice the tensor into batches
+    tensor_list = [input_tensor[i * batch_size:(i + 1) * batch_size] for i in range(num_full_batches)]
+
+    # Handle the last batch if it's smaller than batch_size
+    remainder = input_tensor.size(0) % batch_size
+    if remainder:
+        tensor_list.append(input_tensor[-remainder:])
+
+    return tensor_list
+
+
+def get_colbert_score(query_embedding: np.array, content_embedding: np.array) -> float:
+    query_tensor = torch.tensor(query_embedding)
+    content_tensor = torch.tensor(content_embedding)
+    sim_matrix = torch.nn.functional.cosine_similarity(
+        query_tensor.unsqueeze(2), content_tensor.unsqueeze(1), dim=-1
+    )
+    max_sim_scores, _ = torch.max(sim_matrix, dim=2)
+    return float(torch.mean(max_sim_scores, dim=1))
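
For intuition about the scoring step: get_colbert_score implements ColBERT's MaxSim, that is, build a cosine-similarity matrix between every query token and every document token, take the max over document tokens, then average over query tokens. A minimal runnable sketch, with random arrays standing in for real model output (maxsim_score is a hypothetical standalone mirror of the committed function):

```python
import numpy as np
import torch

def maxsim_score(query_embedding: np.ndarray, content_embedding: np.ndarray) -> float:
    # Token-level cosine similarity: [1, query_len, 1, dim] against [1, 1, doc_len, dim]
    # broadcasts to a [1, query_len, doc_len] similarity matrix.
    query_tensor = torch.tensor(query_embedding)
    content_tensor = torch.tensor(content_embedding)
    sim_matrix = torch.nn.functional.cosine_similarity(
        query_tensor.unsqueeze(2), content_tensor.unsqueeze(1), dim=-1
    )
    # Best-matching document token per query token, averaged over query tokens.
    max_sim_scores, _ = torch.max(sim_matrix, dim=2)
    return float(torch.mean(max_sim_scores, dim=1))

# Random embeddings standing in for model(**encoding).last_hidden_state.
rng = np.random.default_rng(0)
query = rng.normal(size=(1, 5, 768)).astype(np.float32)
doc = rng.normal(size=(1, 11, 768)).astype(np.float32)
print(maxsim_score(query, doc))  # a single relevance score
```
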
37 changes: 37 additions & 0 deletions tests/autorag/nodes/passagereranker/test_colbert_reranker.py
@@ -1,6 +1,11 @@
+import itertools
+
 import pytest
+import torch
+from transformers import AutoModel, AutoTokenizer
 
 from autorag.nodes.passagereranker import colbert_reranker
+from autorag.nodes.passagereranker.colbert import get_colbert_embedding_batch, slice_tensor
 from tests.autorag.nodes.passagereranker.test_passage_reranker_base import queries_example, contents_example, \
     scores_example, ids_example, base_reranker_test, project_dir, previous_result, base_reranker_node_test
 from tests.delete_tests import is_github_action
@@ -15,8 +20,40 @@ def test_colbert_reranker():
     base_reranker_test(contents_result, id_result, score_result, top_k)
 
 
+@pytest.mark.skipif(is_github_action(), reason="Skipping this test on GitHub Actions because it uses local model.")
+def test_colbert_reranker_one_batch():
+    top_k = 2
+    original_colbert_reranker = colbert_reranker.__wrapped__
+    contents_result, id_result, score_result \
+        = original_colbert_reranker(queries_example, contents_example, scores_example, ids_example, top_k, batch=1)
+    base_reranker_test(contents_result, id_result, score_result, top_k)
+
+
 @pytest.mark.skipif(is_github_action(), reason="Skipping this test on GitHub Actions because it uses local model.")
 def test_colbert_reranker_node():
     top_k = 1
     result_df = colbert_reranker(project_dir=project_dir, previous_result=previous_result, top_k=top_k)
     base_reranker_node_test(result_df, top_k)
+
+
+@pytest.mark.skipif(is_github_action(), reason="Skipping this test on GitHub Actions because it uses local model.")
+def test_colbert_embedding():
+    contents = list(itertools.chain.from_iterable(contents_example))
+    model_name = "colbert-ir/colbertv2.0"
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model = AutoModel.from_pretrained(model_name).to(device)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    colbert_embedding = get_colbert_embedding_batch(contents, model, tokenizer, batch_size=2)
+
+    assert isinstance(colbert_embedding, list)
+    assert len(colbert_embedding) == len(contents)
+    assert colbert_embedding[0].shape == (1, 11, 768)
+
+
+def test_slice_tensor():
+    original_tensor = torch.randn(14, 7)
+    batch_size = 4
+    resulting_tensor_list = slice_tensor(original_tensor, batch_size)
+    assert len(resulting_tensor_list) == 4
+    assert resulting_tensor_list[0].size() == torch.Size([4, 7])
+    assert resulting_tensor_list[-1].size() == torch.Size([2, 7])

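A note on the batching the last test pins down: slice_tensor's full-batches-plus-remainder layout matches what torch.split already produces, so the helper can be sanity-checked against it (a small sketch, not part of this commit):

```python
import torch

t = torch.randn(14, 7)
# torch.split chunks along dim 0: three full batches of 4, then a remainder of 2.
assert [b.size(0) for b in torch.split(t, 4)] == [4, 4, 4, 2]
```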