diff --git a/README.md b/README.md
index 87da7c6..37f2fea 100644
--- a/README.md
+++ b/README.md
@@ -14,9 +14,10 @@ Welcome to `rerankers`! Our goal is to provide users with a simple API to use an

 ## Updates

+- v0.5.0: Added support for the current state-of-the-art rerankers: BAAI's series of `BGE` layerwise LLM rerankers, based on [Gemma](https://huggingface.co/BAAI/bge-reranker-v2.5-gemma2-lightweight) and MiniCPM. Unlike RankGPT, these are not listwise: the models are repurposed as "cross-encoders" and directly output logit scores.
 - v0.4.0: ColBERT performance improvement! It should now be faster and result in stronger results following implementation of the JaColBERTv2.5 dynamic query length method. This version also now supports HuggingFace's Text-Embedding-Server (TEI) inference as an API reranker option, thanks to [@srisudarsan](https://github.com/srisudarsan).
 - v0.3.1: T5 bugfix and native default support for new Portuguese T5 rerankers.
-- v0.3.0: 🆕 Many changes! Experimental support for RankLLM, directly backed by the [rank-llm library](https://github.com/castorini/rank_llm). A new `Document` object, courtesy of joint-work by [@bclavie](https://github.com/bclavie) and [Anmol6](https://github.com/Anmol6). This object is transparent, but now offers support for `metadata` stored alongside each document. Many small QoL changes (RankedResults can be itered on directly...)
+- v0.3.0: Many changes! Experimental support for RankLLM, directly backed by the [rank-llm library](https://github.com/castorini/rank_llm). A new `Document` object, courtesy of joint work by [@bclavie](https://github.com/bclavie) and [Anmol6](https://github.com/Anmol6). This object is transparent, but now offers support for `metadata` stored alongside each document. Many small QoL changes (RankedResults can be iterated over directly...)
 - v0.2.0: [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) rerankers, Basic async support thanks to [@tarunamasa](https://github.com/tarunamasa), MixedBread.ai reranking API
 - v0.1.2: Voyage reranking API
 - v0.1.1: Langchain integration fixed!
@@ -198,6 +199,7 @@ Models:
 - ✅ Any standard SentenceTransformer or Transformers cross-encoder
 - ✅ RankGPT (Available both via the original RankGPT implementation and the improved RankLLM one)
 - ✅ T5-based pointwise rankers (InRanker, MonoT5...)
+- ✅ LLM-based pointwise rankers (BAAI/bge-reranker-v2.5-gemma2-lightweight, etc...)
 - ✅ Cohere, Jina, Voyage and MixedBread API rerankers
 - ✅ [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) rerankers (ONNX-optimised models, very fast on CPU)
 - ✅ ColBERT-based reranker - not a model initially designed for reranking, but does perform quite strongly in some cases. Implementation is lightweight, based only on transformers.
diff --git a/pyproject.toml b/pyproject.toml
index fa8013f..833942b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ packages = [
 name = "rerankers"
-version = "0.4.0"
+version = "0.5.0"
 description = "A unified API for various document re-ranking models."
@@ -60,6 +60,7 @@ all = [
     "sentencepiece",
     "protobuf",
     "flashrank",
+    "flash-attn",
     "nmslib-metabrainz; python_version >= '3.10'",
     "rank-llm; python_version >= '3.10'"
 ]
@@ -67,6 +68,7 @@ transformers = ["transformers", "torch", "sentencepiece", "protobuf"]
 api = ["requests"]
 gpt = ["litellm"]
 flashrank = ["flashrank"]
+llmlayerwise = ["transformers", "torch", "sentencepiece", "protobuf", "flash-attn"]
 rankllm = [
     "nmslib-metabrainz; python_version >= '3.10'",
     "rank-llm; python_version >= '3.10'"
 ]
diff --git a/rerankers/__init__.py b/rerankers/__init__.py
index 11e0664..b08fdb4 100644
--- a/rerankers/__init__.py
+++ b/rerankers/__init__.py
@@ -2,4 +2,4 @@
 from rerankers.documents import Document

 __all__ = ["Reranker", "Document"]
-__version__ = "0.4.0"
+__version__ = "0.5.0"
diff --git a/rerankers/models/__init__.py b/rerankers/models/__init__.py
index 7c997b1..cd0439d 100644
--- a/rerankers/models/__init__.py
+++ b/rerankers/models/__init__.py
@@ -45,3 +45,10 @@
     AVAILABLE_RANKERS["RankLLMRanker"] = RankLLMRanker
 except ImportError:
     pass
+
+try:
+    from rerankers.models.llm_layerwise_ranker import LLMLayerWiseRanker
+
+    AVAILABLE_RANKERS["LLMLayerWiseRanker"] = LLMLayerWiseRanker
+except ImportError:
+    pass
diff --git a/rerankers/models/llm_layerwise_ranker.py b/rerankers/models/llm_layerwise_ranker.py
new file mode 100644
index 0000000..673b7e6
--- /dev/null
+++ b/rerankers/models/llm_layerwise_ranker.py
@@ -0,0 +1,198 @@
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from rerankers.models.ranker import BaseRanker
+from rerankers.documents import Document
+from typing import Union, List, Optional
+from rerankers.utils import vprint, get_device, get_dtype, prep_docs
+from rerankers.results import RankedResults, Result
+
+
+PROMPTS = {
+    "BAAI/bge-reranker-v2.5-gemma2-lightweight": "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'.",
+    "default": "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'.",
+}
+
+DEFAULT_PARAMS = {
+    "default": {},
+    "BAAI/bge-multilingual-gemma2": {},
+    "BAAI/bge-reranker-v2-gemma": {},
+    "BAAI/bge-reranker-v2-minicpm-layerwise": {"cutoff_layers": [28]},
+    "BAAI/bge-reranker-v2.5-gemma2-lightweight": {
+        "cutoff_layers": [28],
+        "compress_ratio": 2,
+        "compress_layer": [24, 40],
+    },
+}
+
+
+class LLMLayerWiseRanker(BaseRanker):
+    def __init__(
+        self,
+        model_name_or_path: str = "BAAI/bge-reranker-v2.5-gemma2-lightweight",
+        max_sequence_length: int = 512,
+        dtype: Optional[Union[str, torch.dtype]] = None,
+        device: Optional[Union[str, torch.device]] = None,
+        batch_size: int = 16,
+        verbose: int = 1,
+        prompt: Optional[str] = None,
+        cutoff_layers: Optional[List[int]] = None,
+        compress_ratio: Optional[int] = None,
+        compress_layer: Optional[List[int]] = None,
+    ):
+        self.verbose = verbose
+        self.device = get_device(device, verbose=self.verbose)
+        self.dtype = get_dtype(dtype, self.device, self.verbose)
+        self.batch_size = batch_size
+
+        vprint(
+            f"Loading model {model_name_or_path}, this might take a while...",
+            self.verbose,
+        )
+        vprint(f"Using device {self.device}.", self.verbose)
+        vprint(f"Using dtype {self.dtype}.", self.verbose)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name_or_path, trust_remote_code=True
+        )
+        self.max_sequence_length = max_sequence_length
+        self.tokenizer.model_max_length = self.max_sequence_length
+        self.tokenizer.padding_side = "right"
+
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name_or_path, trust_remote_code=True, torch_dtype=self.dtype
+        ).to(self.device)
+        self.model.eval()
+
+        # Build the layerwise params dict from explicit arguments, falling back
+        # to the model-specific defaults when none are given
+        params = {}
+        if cutoff_layers is not None:
+            params["cutoff_layers"] = cutoff_layers
+        if compress_ratio is not None:
+            params["compress_ratio"] = compress_ratio
+        if compress_layer is not None:
+            params["compress_layer"] = compress_layer
+        if not params:
+            params = DEFAULT_PARAMS.get(model_name_or_path, DEFAULT_PARAMS["default"])
+        self.params = params
+
+        self.prompt = prompt
+        if self.prompt is None:
+            self.prompt = PROMPTS.get(model_name_or_path, PROMPTS["default"])
+
+    def _get_inputs(self, pairs, max_sequence_length: int):
+        prompt = self.prompt
+        sep = "\n"
+        prompt_inputs = self.tokenizer(
+            prompt, return_tensors=None, add_special_tokens=False
+        )["input_ids"]
+        sep_inputs = self.tokenizer(sep, return_tensors=None, add_special_tokens=False)[
+            "input_ids"
+        ]
+        inputs = []
+        for query, passage in pairs:
+            query_inputs = self.tokenizer(
+                f"A: {query}",
+                return_tensors=None,
+                add_special_tokens=False,
+                max_length=max_sequence_length * 3 // 4,
+                truncation=True,
+            )
+            passage_inputs = self.tokenizer(
+                f"B: {passage}",
+                return_tensors=None,
+                add_special_tokens=False,
+                max_length=max_sequence_length,
+                truncation=True,
+            )
+            item = self.tokenizer.prepare_for_model(
+                [self.tokenizer.bos_token_id] + query_inputs["input_ids"],
+                sep_inputs + passage_inputs["input_ids"],
+                truncation="only_second",
+                max_length=max_sequence_length,
+                padding=False,
+                return_attention_mask=False,
+                return_token_type_ids=False,
+                add_special_tokens=False,
+            )
+            item["input_ids"] = item["input_ids"] + sep_inputs + prompt_inputs
+            item["attention_mask"] = [1] * len(item["input_ids"])
+            inputs.append(item)
+
+        return self.tokenizer.pad(
+            inputs,
+            padding=True,
+            max_length=max_sequence_length + len(sep_inputs) + len(prompt_inputs),
+            pad_to_multiple_of=8,
+            return_tensors="pt",
+        )
+
+    @torch.no_grad()
+    def rank(
+        self,
+        query: str,
+        docs: Union[str, List[str], Document, List[Document]],
+        doc_ids: Optional[Union[List[str], List[int]]] = None,
+        metadata: Optional[List[dict]] = None,
+        batch_size: Optional[int] = None,
+        max_sequence_length: Optional[int] = None,
+    ) -> RankedResults:
+        docs = prep_docs(docs, doc_ids, metadata)
+        pairs = [(query, doc.text) for doc in docs]
+
+        # Use the per-call batch_size if provided, otherwise fall back to self.batch_size
+        if batch_size is None:
+            batch_size = self.batch_size
+
+        # Same for max_sequence_length
+        if max_sequence_length is None:
+            max_sequence_length = self.max_sequence_length
+
+        batched_pairs = [
+            pairs[i : i + batch_size] for i in range(0, len(pairs), batch_size)
+        ]
+        scores = []
+
+        for batch in batched_pairs:
+            inputs = self._get_inputs(batch, max_sequence_length=max_sequence_length)
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            outputs = self.model(**inputs, return_dict=True, **self.params)
+            # outputs[0] holds one score tensor per requested cutoff layer;
+            # keep the final-token score from the last requested layer
+            all_scores = [
+                layer_scores[:, -1].view(-1).float() for layer_scores in outputs[0]
+            ]
+            batch_scores = all_scores[-1].cpu().numpy().tolist()
+
+            scores.extend(batch_scores)
+
+        ranked_results = [
+            Result(document=doc, score=score, rank=idx + 1)
+            for idx, (doc, score) in enumerate(
+                sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
+            )
+        ]
+        return RankedResults(results=ranked_results, query=query, has_scores=True)
+
+    @torch.no_grad()
+    def score(self, query: str, doc: str) -> float:
+        inputs = self._get_inputs(
+            [(query, doc)], max_sequence_length=self.max_sequence_length
+        )
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+        outputs = self.model(**inputs, return_dict=True, **self.params)
+        all_scores = [
+            layer_scores[:, -1].view(-1).float() for layer_scores in outputs[0]
+        ]
+        score = all_scores[-1].item()
+
+        return score
diff --git a/rerankers/reranker.py b/rerankers/reranker.py
index 7be39d9..5c7599b 100644
--- a/rerankers/reranker.py
+++ b/rerankers/reranker.py
@@ -30,7 +30,11 @@
         "es": "AdrienB134/ColBERTv2.0-spanish-mmarcoES",
     },
     "flashrank": {"en": "ms-marco-MiniLM-L-12-v2", "other": "ms-marco-MultiBERT-L-12"},
-    "text-embeddings-inference": {"other": "BAAI/bge-reranker-base"}
+    "text-embeddings-inference": {"other": "BAAI/bge-reranker-base"},
+    "llm-layerwise": {
+        "en": "BAAI/bge-reranker-v2.5-gemma2-lightweight",
+        "other": "BAAI/bge-reranker-v2.5-gemma2-lightweight",
+    },
 }

 DEPS_MAPPING = {
@@ -42,6 +46,7 @@
     "ColBERTRanker": "transformers",
     "FlashRankRanker": "flashrank",
     "RankLLMRanker": "rankllm",
+    "LLMLayerWiseRanker": "transformers",
 }

 PROVIDERS = ["cohere", "jina", "voyage", "mixedbread.ai", "text-embeddings-inference"]
@@ -78,6 +83,7 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
             "cross-encoder": "TransformerRanker",
             "flashrank": "FlashRankRanker",
             "rankllm": "RankLLMRanker",
+            "llm-layerwise": "LLMLayerWiseRanker",
         }
         return model_mapping.get(explicit_model_type, explicit_model_type)
     else:
@@ -89,7 +95,6 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
             "rankllm": "RankLLMRanker",
             "rankgpt": "RankGPTRanker",
             "gpt": "RankGPTRanker",
-            "zephyr": "RankZephyr",
             "colbert": "ColBERTRanker",
             "cohere": "APIRanker",
             "jina": "APIRanker",
@@ -99,6 +104,7 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
             "ms-marco-multibert-l-12": "FlashRankRanker",
             "vicuna": "RankLLMRanker",
             "zephyr": "RankLLMRanker",
+            "bge-reranker-v2.5-gemma2-lightweight": "LLMLayerWiseRanker",
         }
         for key, value in model_mapping.items():
             if key in model_name:
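Below is a minimal usage sketch for the new LLMLayerWiseRanker introduced in this diff (not part of the patch itself). It assumes the Reranker(model_name, model_type=...) factory wired up in rerankers/reranker.py above, a working transformers/torch install, and uses placeholder query text and documents:

    from rerankers import Reranker

    # "llm-layerwise" is the model_type string registered in reranker.py above;
    # the default checkpoint is BAAI/bge-reranker-v2.5-gemma2-lightweight.
    ranker = Reranker(
        "BAAI/bge-reranker-v2.5-gemma2-lightweight",
        model_type="llm-layerwise",
    )

    # Placeholder query and passages, purely for illustration.
    results = ranker.rank(
        query="How do BGE layerwise rerankers score passages?",
        docs=[
            "BGE layerwise rerankers repurpose an LLM as a cross-encoder and output logit scores.",
            "Paris is the capital of France.",
        ],
        doc_ids=[0, 1],
    )

    # RankedResults can be iterated over directly (see the v0.3.0 changelog entry).
    for result in results:
        print(result.rank, result.score, result.document.text)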
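For the layerwise-specific parameters, a direct-instantiation sketch (again an illustration, not part of the patch); the argument names follow the LLMLayerWiseRanker.__init__ signature added above, and the values shown simply mirror the DEFAULT_PARAMS entry for the gemma2-lightweight checkpoint:

    from rerankers.models.llm_layerwise_ranker import LLMLayerWiseRanker

    # cutoff_layers / compress_ratio / compress_layer match the DEFAULT_PARAMS
    # entry for BAAI/bge-reranker-v2.5-gemma2-lightweight; omitting them falls
    # back to those defaults automatically.
    ranker = LLMLayerWiseRanker(
        "BAAI/bge-reranker-v2.5-gemma2-lightweight",
        batch_size=8,
        cutoff_layers=[28],
        compress_ratio=2,
        compress_layer=[24, 40],
    )
    print(ranker.score("query text here", "candidate passage here"))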