diff --git a/README.md b/README.md
index a787806..6af2f16 100644
--- a/README.md
+++ b/README.md
@@ -14,6 +14,7 @@ Welcome to `rerankers`! Our goal is to provide users with a simple API to use an
 ## Updates
 
+- v0.4.0: ColBERT performance improvement! It should now be faster and produce stronger results, following the implementation of the JaColBERTv2.5 dynamic query length method. This version also adds support for HuggingFace's Text Embeddings Inference (TEI) as an API reranker option, thanks to [@srisudarsan](https://github.com/srisudarsan).
 - v0.3.1: T5 bugfix and native default support for new Portuguese T5 rerankers.
 - v0.3.0: 🆕 Many changes! Experimental support for RankLLM, directly backed by the [rank-llm library](https://github.com/castorini/rank_llm). A new `Document` object, courtesy of joint work by [@bclavie](https://github.com/bclavie) and [Anmol6](https://github.com/Anmol6). This object is transparent, but now offers support for `metadata` stored alongside each document. Many small QoL changes (RankedResults can be iterated over directly...)
 - v0.2.0: [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) rerankers, basic async support thanks to [@tarunamasa](https://github.com/tarunamasa), MixedBread.ai reranking API
diff --git a/pyproject.toml b/pyproject.toml
index bec2405..fa8013f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ packages = [
 
 name = "rerankers"
-version = "0.3.1"
+version = "0.4.0"
 description = "A unified API for various document re-ranking models."
diff --git a/rerankers/__init__.py b/rerankers/__init__.py
index e1cc7ee..11e0664 100644
--- a/rerankers/__init__.py
+++ b/rerankers/__init__.py
@@ -2,4 +2,4 @@
 from rerankers.documents import Document
 
 __all__ = ["Reranker", "Document"]
-__version__ = "0.3.1"
+__version__ = "0.4.0"
diff --git a/rerankers/models/colbert_ranker.py b/rerankers/models/colbert_ranker.py
index 4533ae3..145fd81 100644
--- a/rerankers/models/colbert_ranker.py
+++ b/rerankers/models/colbert_ranker.py
@@ -2,7 +2,8 @@
 Modifications include packaging into a BaseRanker, dynamic query/doc length and batch size handling."""
 
 import torch
-from transformers import AutoModel, AutoTokenizer
+import torch.nn as nn
+from transformers import BertPreTrainedModel, BertModel, AutoModel, AutoTokenizer
 from typing import List, Optional, Union
 from math import ceil
@@ -67,17 +68,140 @@ def _insert_token(
     return updated_output
 
 
-def _colbert_score(
-    q_reps,
-    p_reps,
-    q_mask: torch.Tensor,
-    p_mask: torch.Tensor,
-):
+def _colbert_score(q_reps, p_reps, q_mask: torch.Tensor, p_mask: torch.Tensor):
+    # calc max sim
+    # base code from: https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py
+
+    # Assert that all q_reps are at least as long as the query length
+    assert (
+        q_reps.shape[1] >= q_mask.shape[1]
+    ), f"q_reps should have at least {q_mask.shape[1]} tokens, but has {q_reps.shape[1]}"
+
     token_scores = torch.einsum("qin,pjn->qipj", q_reps, p_reps)
     token_scores = token_scores.masked_fill(p_mask.unsqueeze(0).unsqueeze(0) == 0, -1e4)
     scores, _ = token_scores.max(-1)
+    scores = scores.sum(1) / q_mask.sum(-1, keepdim=True)
+    return scores
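For readers skimming the hunk above: this is the standard ColBERT MaxSim scoring. Each query token is compared against every document token, keeps its single best match, and the per-token maxima are averaged over the attended query length (the denominator change from `q_mask[:, 1:]` to `q_mask`, removed further down, is what accounts for the new dynamic-length query padding). A self-contained sketch of the same computation on toy tensors; all shapes and values here are made up for illustration:

```python
import torch

# Toy shapes: 1 query vs. 2 docs, 4 query tokens, 6 doc tokens, embedding dim 8.
q_reps = torch.randn(1, 4, 8)          # (num_queries, q_len, dim)
p_reps = torch.randn(2, 6, 8)          # (num_docs, p_len, dim)
q_mask = torch.tensor([[1, 1, 1, 0]])  # last query position is padding
p_mask = torch.ones(2, 6)              # all doc tokens are real

# Token-level similarities: (num_queries, q_len, num_docs, p_len)
token_scores = torch.einsum("qin,pjn->qipj", q_reps, p_reps)
# Mask padded document tokens so they can never win the max
token_scores = token_scores.masked_fill(p_mask.unsqueeze(0).unsqueeze(0) == 0, -1e4)
# MaxSim: best-matching document token per query token...
scores, _ = token_scores.max(-1)
# ...summed over query positions and normalized by the real query length
scores = scores.sum(1) / q_mask.sum(-1, keepdim=True)
print(scores.shape)  # torch.Size([1, 2]): one score per (query, doc) pair
```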
+
+
+class ColBERTModel(BertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.bert = BertModel(config)
+        self.linear = nn.Linear(config.hidden_size, 128, bias=False)
+        self.init_weights()
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        output_attentions=None,
+        output_hidden_states=None,
+    ):
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=True,  # Always output hidden states
+        )
+
+        sequence_output = outputs[0]
+
+        return self.linear(sequence_output)
+
+    def _encode(self, texts: list[str], insert_token_id: int, is_query: bool = False):
+        encoding = self.tokenizer(
+            texts,
+            return_tensors="pt",
+            padding=True,
+            max_length=self.max_length - 1,  # for insert token
+            truncation=True,
+        )
+        encoding = _insert_token(encoding, insert_token_id)  # type: ignore
 
-    return scores.sum(1) / q_mask[:, 1:].sum(-1, keepdim=True)
+        if is_query:
+            mask_token_id = self.tokenizer.mask_token_id
+
+            new_encodings = {"input_ids": [], "attention_mask": []}
+
+            for i, input_ids in enumerate(encoding["input_ids"]):
+                original_length = (
+                    (input_ids != self.tokenizer.pad_token_id).sum().item()
+                )
+
+                # Calculate QLEN dynamically for each query
+                if original_length % 32 <= 8:
+                    QLEN = original_length + 8
+                else:
+                    QLEN = ceil(original_length / 32) * 32
+
+                if original_length < QLEN:
+                    pad_length = QLEN - original_length
+                    padded_input_ids = input_ids.tolist() + [mask_token_id] * pad_length
+                    padded_attention_mask = (
+                        encoding["attention_mask"][i].tolist() + [0] * pad_length
+                    )
+                else:
+                    padded_input_ids = input_ids[:QLEN].tolist()
+                    padded_attention_mask = encoding["attention_mask"][i][
+                        :QLEN
+                    ].tolist()
+
+                new_encodings["input_ids"].append(padded_input_ids)
+                new_encodings["attention_mask"].append(padded_attention_mask)
+
+            for key in new_encodings:
+                new_encodings[key] = torch.tensor(
+                    new_encodings[key], device=self.device
+                )
+
+            encoding = new_encodings
+
+        encoding = {key: value.to(self.device) for key, value in encoding.items()}
+        return encoding
+
+    def _query_encode(self, query: list[str]):
+        return self._encode(query, self.query_token_id, is_query=True)
+
+    def _document_encode(self, documents: list[str]):
+        return self._encode(documents, self.document_token_id)
+
+    def _to_embs(self, encoding) -> torch.Tensor:
+        with torch.no_grad():
+            # embs = self.model(**encoding).last_hidden_state.squeeze(1)
+            embs = self.model(**encoding)
+        if self.normalize:
+            embs = embs / embs.norm(dim=-1, keepdim=True)
+        return embs
+
+    def _rerank(self, query: str, documents: list[str]) -> list[float]:
+        query_encoding = self._query_encode([query])
+        documents_encoding = self._document_encode(documents)
+        query_embeddings = self._to_embs(query_encoding)
+        document_embeddings = self._to_embs(documents_encoding)
+        scores = (
+            _colbert_score(
+                query_embeddings,
+                document_embeddings,
+                query_encoding["attention_mask"],
+                documents_encoding["attention_mask"],
+            )
+            .cpu()
+            .tolist()[0]
+        )
+        return scores
 
 
 class ColBERTRanker(BaseRanker):
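The `is_query` branch added above (and mirrored in `ColBERTRanker._encode` in the next hunks) is the JaColBERTv2.5 dynamic query length method: rather than always padding queries to a fixed length, each query is padded with `[MASK]` tokens (ColBERT's query augmentation) to a dynamically chosen QLEN. Queries at or within 8 tokens past a multiple of 32 get exactly 8 extra `[MASK]` positions; all others are rounded up to the next multiple of 32. The `[MASK]` positions get attention-mask 0, so they are excluded from the query-length denominator in `_colbert_score`. A small sketch of just the bucketing rule; `dynamic_qlen` is a hypothetical helper name extracted here for illustration:

```python
from math import ceil

def dynamic_qlen(original_length: int) -> int:
    # Hypothetical helper mirroring the QLEN rule from the is_query branch above.
    # At, or within 8 tokens past, a multiple of 32: add exactly 8 [MASK] slots.
    if original_length % 32 <= 8:
        return original_length + 8
    # Otherwise: round up to the next multiple of 32.
    return ceil(original_length / 32) * 32

for n in (5, 9, 31, 33, 40, 41, 64):
    print(f"{n:3d} -> {dynamic_qlen(n)}")
# 5 -> 13, 9 -> 32, 31 -> 32, 33 -> 41, 40 -> 48, 41 -> 64, 64 -> 72
```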
@@ -159,14 +283,9 @@ def _colbert_rank(
         return scores
 
     def _query_encode(self, query: list[str]):
-        tokenized_query_length = len(self.tokenizer.encode(query[0]))
-        max_length = max(
-            ceil(tokenized_query_length / 16) * 16, self.query_max_length
-        )  # Ensure not smaller than query_max_length
-        max_length = int(
-            min(max_length, self.doc_max_length)
-        )  # Ensure not larger than doc_max_length
-        return self._encode(query, self.query_token_id, max_length)
+        return self._encode(
+            query, self.query_token_id, max_length=self.doc_max_length, is_query=True
+        )
 
     def _document_encode(self, documents: list[str]):
         tokenized_doc_lengths = [
@@ -189,7 +308,13 @@ def _document_encode(self, documents: list[str]):
         )  # Ensure not larger than doc_max_length
         return self._encode(documents, self.document_token_id, max_length)
 
-    def _encode(self, texts: list[str], insert_token_id: int, max_length: int):
+    def _encode(
+        self,
+        texts: list[str],
+        insert_token_id: int,
+        max_length: int,
+        is_query: bool = False,
+    ):
         encoding = self.tokenizer(
             texts,
             return_tensors="pt",
@@ -198,6 +323,45 @@ def _encode(self, texts: list[str], insert_token_id: int, max_length: int):
             truncation=True,
         )
         encoding = _insert_token(encoding, insert_token_id)  # type: ignore
+
+        if is_query:
+            mask_token_id = self.tokenizer.mask_token_id
+
+            new_encodings = {"input_ids": [], "attention_mask": []}
+
+            for i, input_ids in enumerate(encoding["input_ids"]):
+                original_length = (
+                    (input_ids != self.tokenizer.pad_token_id).sum().item()
+                )
+
+                # Calculate QLEN dynamically for each query
+                if original_length % 32 <= 8:
+                    QLEN = original_length + 8
+                else:
+                    QLEN = ceil(original_length / 32) * 32
+
+                if original_length < QLEN:
+                    pad_length = QLEN - original_length
+                    padded_input_ids = input_ids.tolist() + [mask_token_id] * pad_length
+                    padded_attention_mask = (
+                        encoding["attention_mask"][i].tolist() + [0] * pad_length
+                    )
+                else:
+                    padded_input_ids = input_ids[:QLEN].tolist()
+                    padded_attention_mask = encoding["attention_mask"][i][
+                        :QLEN
+                    ].tolist()
+
+                new_encodings["input_ids"].append(padded_input_ids)
+                new_encodings["attention_mask"].append(padded_attention_mask)
+
+            for key in new_encodings:
+                new_encodings[key] = torch.tensor(
+                    new_encodings[key], device=self.device
+                )
+
+            encoding = new_encodings
+
         encoding = {key: value.to(self.device) for key, value in encoding.items()}
         return encoding
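None of these internals need to be called directly: the new query handling is picked up automatically whenever a ColBERT reranker scores a query through the library's top-level API. A minimal usage sketch, assuming the `Reranker` entry point and `Result` attributes (`.score`, `.document.text`) behave as documented in the pre-existing `rerankers` README; the query and documents are made up:

```python
from rerankers import Reranker

# "colbert" resolves to the library's default ColBERT reranker;
# a specific HuggingFace checkpoint name can be passed instead.
ranker = Reranker("colbert")

results = ranker.rank(
    query="How does dynamic query length padding work?",
    docs=[
        "Queries are padded with [MASK] tokens to a dynamically chosen length.",
        "Unrelated text about cooking pasta.",
    ],
)

# RankedResults can be iterated over directly (since v0.3.0)
for result in results:
    print(result.score, result.document.text)
```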