Merge pull request #21 from AnswerDotAI/improved_colbert
Improved ColBERT
bclavie authored Jul 30, 2024
2 parents b407dc5 + e3e12a8 commit dd61524
Showing 4 changed files with 184 additions and 19 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -14,6 +14,7 @@ Welcome to `rerankers`! Our goal is to provide users with a simple API to use an

## Updates

- v0.4.0: ColBERT performance improvement! It should now be faster and produce stronger results following the implementation of the JaColBERTv2.5 dynamic query-length method. This version also adds support for Hugging Face's Text Embeddings Inference (TEI) as an API reranker option, thanks to [@srisudarsan](https://github.com/srisudarsan). (See the usage sketch after this list.)
- v0.3.1: T5 bugfix and native default support for new Portuguese T5 rerankers.
- v0.3.0: 🆕 Many changes! Experimental support for RankLLM, directly backed by the [rank-llm library](https://github.com/castorini/rank_llm). A new `Document` object, courtesy of joint work by [@bclavie](https://github.com/bclavie) and [Anmol6](https://github.com/Anmol6). This object is transparent, but now offers support for `metadata` stored alongside each document. Many small QoL changes (RankedResults can be iterated over directly...)
- v0.2.0: [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) rerankers, Basic async support thanks to [@tarunamasa](https://github.com/tarunamasa), MixedBread.ai reranking API
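
For quick orientation, a minimal usage sketch of the API these release notes describe (illustrative only: the "colbert" shorthand, default model choice, and result fields below are assumptions, not taken from this diff):

from rerankers import Reranker

# "colbert" is assumed to select the library's default ColBERT reranker.
ranker = Reranker("colbert")

# Rank a few documents against one query; returns RankedResults, which
# (per the v0.3.0 note above) can be iterated over directly.
results = ranker.rank(
    query="How does late-interaction scoring work?",
    docs=[
        "ColBERT compares query and document token embeddings with MaxSim.",
        "BM25 is a purely lexical ranking function.",
    ],
)
for result in results:
    print(result.score, result.document.text)
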
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -14,7 +14,7 @@ packages = [
name = "rerankers"


version = "0.3.1"
version = "0.4.0"

description = "A unified API for various document re-ranking models."

2 changes: 1 addition & 1 deletion rerankers/__init__.py
@@ -2,4 +2,4 @@
from rerankers.documents import Document

__all__ = ["Reranker", "Document"]
__version__ = "0.3.1"
__version__ = "0.4.0"
198 changes: 181 additions & 17 deletions rerankers/models/colbert_ranker.py
@@ -2,7 +2,8 @@
Modifications include packaging into a BaseRanker, dynamic query/doc length and batch size handling."""

import torch
- from transformers import AutoModel, AutoTokenizer
import torch.nn as nn
from transformers import BertPreTrainedModel, BertModel, AutoModel, AutoTokenizer
from typing import List, Optional, Union
from math import ceil

@@ -67,17 +68,140 @@ def _insert_token(
return updated_output


- def _colbert_score(
- q_reps,
- p_reps,
- q_mask: torch.Tensor,
- p_mask: torch.Tensor,
- ):
def _colbert_score(q_reps, p_reps, q_mask: torch.Tensor, p_mask: torch.Tensor):
# calc max sim
# base code from: https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py

# Assert that all q_reps are at least as long as the query length
assert (
q_reps.shape[1] >= q_mask.shape[1]
), f"q_reps should have at least {q_mask.shape[1]} tokens, but has {q_reps.shape[1]}"

token_scores = torch.einsum("qin,pjn->qipj", q_reps, p_reps)
token_scores = token_scores.masked_fill(p_mask.unsqueeze(0).unsqueeze(0) == 0, -1e4)
scores, _ = token_scores.max(-1)
scores = scores.sum(1) / q_mask.sum(-1, keepdim=True)
return scores
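
As an editorial aside (not part of the commit), a minimal shape sketch for _colbert_score, assuming random tensors and full attention masks:

import torch

q_reps = torch.randn(1, 32, 128)   # (num_queries, query_tokens, dim)
p_reps = torch.randn(4, 180, 128)  # (num_docs, doc_tokens, dim)
q_mask = torch.ones(1, 32)         # 1 for real query tokens, 0 for padding
p_mask = torch.ones(4, 180)        # 1 for real document tokens

# The einsum gives (1, 32, 4, 180); the max over doc tokens is MaxSim per
# query token; summing over query tokens and dividing by q_mask.sum
# normalises by the number of real query tokens.
scores = _colbert_score(q_reps, p_reps, q_mask, p_mask)  # shape (1, 4)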


class ColBERTModel(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.bert = BertModel(config)
self.linear = nn.Linear(config.hidden_size, 128, bias=False)
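# Projects each token's final hidden state to a 128-dim ColBERT embedding;
# these per-token vectors are what _colbert_score compares via MaxSim.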
self.init_weights()

def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
):
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=True, # Always output hidden states
)

sequence_output = outputs[0]

return self.linear(sequence_output)

def _encode(self, texts: list[str], insert_token_id: int, is_query: bool = False):
encoding = self.tokenizer(
texts,
return_tensors="pt",
padding=True,
max_length=self.max_length - 1, # for insert token
truncation=True,
)
encoding = _insert_token(encoding, insert_token_id) # type: ignore

- return scores.sum(1) / q_mask[:, 1:].sum(-1, keepdim=True)  # old _colbert_score return, removed in this commit
if is_query:
mask_token_id = self.tokenizer.mask_token_id

new_encodings = {"input_ids": [], "attention_mask": []}

for i, input_ids in enumerate(encoding["input_ids"]):
original_length = (
(input_ids != self.tokenizer.pad_token_id).sum().item()
)

# Calculate QLEN dynamically for each query
if original_length % 32 <= 8:
QLEN = original_length + 8
else:
QLEN = ceil(original_length / 32) * 32
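# Worked examples of the dynamic QLEN rule (illustrative comments, not part of the commit):
#   original_length = 6  -> 6 % 32 = 6,  which is <= 8, so QLEN = 6 + 8 = 14
#   original_length = 20 -> 20 % 32 = 20, which is > 8,  so QLEN = ceil(20 / 32) * 32 = 32
#   original_length = 36 -> 36 % 32 = 4,  which is <= 8, so QLEN = 36 + 8 = 44
# The query is then padded up to QLEN with [MASK] tokens below (ColBERT-style query augmentation).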

if original_length < QLEN:
pad_length = QLEN - original_length
padded_input_ids = input_ids.tolist() + [mask_token_id] * pad_length
padded_attention_mask = (
encoding["attention_mask"][i].tolist() + [0] * pad_length
)
else:
padded_input_ids = input_ids[:QLEN].tolist()
padded_attention_mask = encoding["attention_mask"][i][
:QLEN
].tolist()

new_encodings["input_ids"].append(padded_input_ids)
new_encodings["attention_mask"].append(padded_attention_mask)

for key in new_encodings:
new_encodings[key] = torch.tensor(
new_encodings[key], device=self.device
)

encoding = new_encodings

encoding = {key: value.to(self.device) for key, value in encoding.items()}
return encoding

def _query_encode(self, query: list[str]):
return self._encode(query, self.query_token_id, is_query=True)

def _document_encode(self, documents: list[str]):
return self._encode(documents, self.document_token_id)

def _to_embs(self, encoding) -> torch.Tensor:
with torch.no_grad():
# embs = self.model(**encoding).last_hidden_state.squeeze(1)
embs = self.model(**encoding)
if self.normalize:
embs = embs / embs.norm(dim=-1, keepdim=True)
return embs

def _rerank(self, query: str, documents: list[str]) -> list[float]:
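# Scores every document against the single query: one float per document,
# index-aligned with `documents` (tolist()[0] selects the only query row).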
query_encoding = self._query_encode([query])
documents_encoding = self._document_encode(documents)
query_embeddings = self._to_embs(query_encoding)
document_embeddings = self._to_embs(documents_encoding)
scores = (
_colbert_score(
query_embeddings,
document_embeddings,
query_encoding["attention_mask"],
documents_encoding["attention_mask"],
)
.cpu()
.tolist()[0]
)
return scores


class ColBERTRanker(BaseRanker):
@@ -159,14 +283,9 @@ def _colbert_rank(
return scores

def _query_encode(self, query: list[str]):
- tokenized_query_length = len(self.tokenizer.encode(query[0]))
- max_length = max(
- ceil(tokenized_query_length / 16) * 16, self.query_max_length
- ) # Ensure not smaller than query_max_length
- max_length = int(
- min(max_length, self.doc_max_length)
- ) # Ensure not larger than doc_max_length
- return self._encode(query, self.query_token_id, max_length)
return self._encode(
query, self.query_token_id, max_length=self.doc_max_length, is_query=True
)

def _document_encode(self, documents: list[str]):
tokenized_doc_lengths = [
@@ -189,7 +308,13 @@ def _document_encode(self, documents: list[str]):
) # Ensure not larger than doc_max_length
return self._encode(documents, self.document_token_id, max_length)

- def _encode(self, texts: list[str], insert_token_id: int, max_length: int):
def _encode(
self,
texts: list[str],
insert_token_id: int,
max_length: int,
is_query: bool = False,
):
encoding = self.tokenizer(
texts,
return_tensors="pt",
@@ -198,6 +323,45 @@ def _encode(self, texts: list[str], insert_token_id: int, max_length: int):
truncation=True,
)
encoding = _insert_token(encoding, insert_token_id) # type: ignore

if is_query:
mask_token_id = self.tokenizer.mask_token_id

new_encodings = {"input_ids": [], "attention_mask": []}

for i, input_ids in enumerate(encoding["input_ids"]):
original_length = (
(input_ids != self.tokenizer.pad_token_id).sum().item()
)

# Calculate QLEN dynamically for each query
if original_length % 32 <= 8:
QLEN = original_length + 8
else:
QLEN = ceil(original_length / 32) * 32
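# Same per-query dynamic QLEN rule as in the earlier _encode above; see the worked examples there.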

if original_length < QLEN:
pad_length = QLEN - original_length
padded_input_ids = input_ids.tolist() + [mask_token_id] * pad_length
padded_attention_mask = (
encoding["attention_mask"][i].tolist() + [0] * pad_length
)
else:
padded_input_ids = input_ids[:QLEN].tolist()
padded_attention_mask = encoding["attention_mask"][i][
:QLEN
].tolist()

new_encodings["input_ids"].append(padded_input_ids)
new_encodings["attention_mask"].append(padded_attention_mask)

for key in new_encodings:
new_encodings[key] = torch.tensor(
new_encodings[key], device=self.device
)

encoding = new_encodings

encoding = {key: value.to(self.device) for key, value in encoding.items()}
return encoding

