Merge pull request #25 from AnswerDotAI/feat/llmlayerwise
Feat/llmlayerwise
bclavie authored Aug 16, 2024
2 parents e8e83b8 + 810f772 commit e982e80
Showing 6 changed files with 220 additions and 5 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -14,9 +14,10 @@ Welcome to `rerankers`! Our goal is to provide users with a simple API to use an

## Updates

- v0.5.0: Added support for the current state-of-the-art rerankers: BAAI's `BGE` series of layerwise LLM rerankers, based on [Gemma](https://huggingface.co/BAAI/bge-reranker-v2.5-gemma2-lightweight) and MiniCPM. These are different from RankGPT, as they're not listwise: the models are repurposed as "cross-encoders", and do output logit scores (see the usage sketch after this list).
- v0.4.0: ColBERT performance improvement! It should now be faster and produce stronger results, following the implementation of the JaColBERTv2.5 dynamic query-length method. This version also adds support for HuggingFace's Text Embeddings Inference (TEI) as an API reranker option, thanks to [@srisudarsan](https://github.com/srisudarsan).
- v0.3.1: T5 bugfix and native default support for new Portuguese T5 rerankers.
- v0.3.0: Many changes! Experimental support for RankLLM, directly backed by the [rank-llm library](https://github.com/castorini/rank_llm). A new `Document` object, courtesy of joint work by [@bclavie](https://github.com/bclavie) and [Anmol6](https://github.com/Anmol6). This object is transparent, but now offers support for `metadata` stored alongside each document. Many small QoL changes (RankedResults can be iterated over directly...)
- v0.2.0: [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) rerankers, Basic async support thanks to [@tarunamasa](https://github.com/tarunamasa), MixedBread.ai reranking API
- v0.1.2: Voyage reranking API
- v0.1.1: Langchain integration fixed!
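A hedged usage sketch for the new layerwise rerankers. Only the `llm-layerwise` → `LLMLayerWiseRanker` mapping, the default model choice, and the `rank()` signature are visible in this diff, so the exact `Reranker(...)` keyword arguments below are assumptions:

```python
from rerankers import Reranker

# "llm-layerwise" maps to BAAI/bge-reranker-v2.5-gemma2-lightweight in the default-model
# table added to rerankers/reranker.py below; pass a full model name to pick another variant.
ranker = Reranker("llm-layerwise", lang="en")

results = ranker.rank(
    query="What is the capital of France?",
    docs=[
        "Paris is the capital and largest city of France.",
        "Berlin is the capital of Germany.",
    ],
    doc_ids=[0, 1],
)

# RankedResults can be iterated over directly (see the v0.3.0 note above).
for result in results:
    print(result.rank, result.score, result.document.text)
```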
@@ -198,6 +199,7 @@ Models:
- ✅ Any standard SentenceTransformer or Transformers cross-encoder
- ✅ RankGPT (Available both via the original RankGPT implementation and the improved RankLLM one)
- ✅ T5-based pointwise rankers (InRanker, MonoT5...)
- ✅ LLM-based pointwise rankers (BAAI/bge-reranker-v2.5-gemma2-lightweight, etc...)
- ✅ Cohere, Jina, Voyage and MixedBread API rerankers
- ✅ [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) rerankers (ONNX-optimised models, very fast on CPU)
- ✅ ColBERT-based reranker - not a model initially designed for reranking, but does perform quite strongly in some cases. Implementation is lightweight, based only on transformers.
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -14,7 +14,7 @@ packages = [
name = "rerankers"


version = "0.4.0"
version = "0.5.0"

description = "A unified API for various document re-ranking models."

@@ -60,13 +60,15 @@ all = [
"sentencepiece",
"protobuf",
"flashrank",
"flash-attn",
"nmslib-metabrainz; python_version >= '3.10'",
"rank-llm; python_version >= '3.10'"
]
transformers = ["transformers", "torch", "sentencepiece", "protobuf"]
api = ["requests"]
gpt = ["litellm"]
flashrank = ["flashrank"]
llmlayerwise = ["transformers", "torch", "sentencepiece", "protobuf", "flash-attn"]
rankllm = [
"nmslib-metabrainz; python_version >= '3.10'",
"rank-llm; python_version >= '3.10'"
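Hedged install note for the new `llmlayerwise` extra declared above (names taken from this pyproject.toml hunk; whether the release ships with these extras is an assumption). A quick sanity check that the optional heavy dependencies resolved might look like:

```python
# pip install "rerankers[llmlayerwise]"   # new optional group from the hunk above
# flash-attn generally needs a CUDA build toolchain; CPU-only setups may prefer the plain
# "transformers" extra and skip flash attention.
import importlib.util

for module in ("torch", "transformers", "sentencepiece"):
    if importlib.util.find_spec(module) is None:
        raise ImportError(f"Optional dependency missing for the LLM layerwise rerankers: {module}")
```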
2 changes: 1 addition & 1 deletion rerankers/__init__.py
@@ -2,4 +2,4 @@
from rerankers.documents import Document

__all__ = ["Reranker", "Document"]
__version__ = "0.4.0"
__version__ = "0.5.0"
7 changes: 7 additions & 0 deletions rerankers/models/__init__.py
@@ -45,3 +45,10 @@
AVAILABLE_RANKERS["RankLLMRanker"] = RankLLMRanker
except ImportError:
pass

try:
from rerankers.models.llm_layerwise_ranker import LLMLayerWiseRanker

AVAILABLE_RANKERS["LLMLayerWiseRanker"] = LLMLayerWiseRanker
except ImportError:
pass
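For orientation, this follows the file's existing optional-import pattern: a ranker is registered in `AVAILABLE_RANKERS` only when its dependencies import cleanly. A small sketch of a downstream availability check (assuming `AVAILABLE_RANKERS` is importable from `rerankers.models`, as this file suggests):

```python
from rerankers.models import AVAILABLE_RANKERS

if "LLMLayerWiseRanker" not in AVAILABLE_RANKERS:
    raise ImportError(
        "LLMLayerWiseRanker is not registered - install the transformers/torch extras first."
    )
```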
198 changes: 198 additions & 0 deletions rerankers/models/llm_layerwise_ranker.py
@@ -0,0 +1,198 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from rerankers.models.ranker import BaseRanker
from rerankers.documents import Document
from typing import Union, List, Optional
from rerankers.utils import vprint, get_device, get_dtype, prep_docs
from rerankers.results import RankedResults, Result


PROMPTS = {
"BAAI/bge-reranker-v2.5-gemma2-lightweight": "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'.",
"default": "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'.",
}

DEFAULT_PARAMS = {
"default": {},
"BAAI/bge-multilingual-gemma2": {},
"BAAI/bge-reranker-v2-gemma": {},
"BAAI/bge-reranker-v2-minicpm-layerwise": {"cutoff_layers": [28]},
"BAAI/bge-reranker-v2.5-gemma2-lightweight": {
"cutoff_layers": [28],
"compress_ratio": 2,
"compress_layer": [24, 40],
},
}
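# Note on the tables above: PROMPTS supplies the instruction appended after each
# query/passage pair, and DEFAULT_PARAMS entries are forwarded verbatim to the model's
# forward call (via `self.params` below). `cutoff_layers` selects which decoder layer(s)
# emit the relevance logit for the layerwise rerankers; `compress_ratio` / `compress_layer`
# are the lightweight Gemma2 variant's token-compression settings (per the BAAI model card).
# Models without an entry fall back to the "default" key (no extra kwargs).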


class LLMLayerWiseRanker(BaseRanker):
def __init__(
self,
model_name_or_path: str = "BAAI/bge-reranker-v2.5-gemma2-lightweight",
max_sequence_length: int = 512,
dtype: Optional[Union[str, torch.dtype]] = None,
device: Optional[Union[str, torch.device]] = None,
batch_size: int = 16,
verbose: int = 1,
prompt: Optional[str] = None,
cutoff_layers: Optional[List[int]] = None,
compress_ratio: Optional[int] = None,
compress_layer: Optional[List[int]] = None,
):
self.verbose = verbose
self.device = get_device(device, verbose=self.verbose)
self.dtype = get_dtype(dtype, self.device, self.verbose)
self.batch_size = batch_size

vprint(
f"Loading model {model_name_or_path}, this might take a while...",
self.verbose,
)
vprint(f"Using device {self.device}.", self.verbose)
vprint(f"Using dtype {self.dtype}.", self.verbose)

self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path, trust_remote_code=True
)
self.max_sequence_length = max_sequence_length
self.tokenizer.model_max_length = self.max_sequence_length
self.tokenizer.padding_side = "right"

self.model = AutoModelForCausalLM.from_pretrained(
model_name_or_path, trust_remote_code=True, torch_dtype=self.dtype
).to(self.device)
self.model.eval()

# Create params dict based on specified values or defaults
params = {}
if cutoff_layers is not None:
params["cutoff_layers"] = cutoff_layers
if compress_ratio is not None:
params["compress_ratio"] = compress_ratio
if compress_layer is not None:
params["compress_layer"] = compress_layer
if not params:
params = DEFAULT_PARAMS.get(model_name_or_path, DEFAULT_PARAMS["default"])
self.params = params

self.prompt = prompt
if self.prompt is None:
self.prompt = PROMPTS.get(model_name_or_path, PROMPTS["default"])
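    # _get_inputs builds, for each (query, passage) pair, a single sequence laid out as:
    #   [bos] "A: {query}" \n "B: {passage}" \n {prompt}
    # with the query capped at ~3/4 of max_sequence_length, truncation applied on the
    # passage side, and the batch padded to a multiple of 8 tokens.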

def _get_inputs(self, pairs, max_sequence_length: int):
prompt = self.prompt
sep = "\n"
prompt_inputs = self.tokenizer(
prompt, return_tensors=None, add_special_tokens=False
)["input_ids"]
sep_inputs = self.tokenizer(sep, return_tensors=None, add_special_tokens=False)[
"input_ids"
]
inputs = []
for query, passage in pairs:
query_inputs = self.tokenizer(
f"A: {query}",
return_tensors=None,
add_special_tokens=False,
max_length=max_sequence_length * 3 // 4,
truncation=True,
)
passage_inputs = self.tokenizer(
f"B: {passage}",
return_tensors=None,
add_special_tokens=False,
max_length=max_sequence_length,
truncation=True,
)
item = self.tokenizer.prepare_for_model(
[self.tokenizer.bos_token_id] + query_inputs["input_ids"],
sep_inputs + passage_inputs["input_ids"],
truncation="only_second",
max_length=max_sequence_length,
padding=False,
return_attention_mask=False,
return_token_type_ids=False,
add_special_tokens=False,
)
item["input_ids"] = item["input_ids"] + sep_inputs + prompt_inputs
item["attention_mask"] = [1] * len(item["input_ids"])
inputs.append(item)

return self.tokenizer.pad(
inputs,
padding=True,
max_length=max_sequence_length + len(sep_inputs) + len(prompt_inputs),
pad_to_multiple_of=8,
return_tensors="pt",
)
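    # Scoring: these models return one logit tensor per requested cutoff layer in
    # outputs[0]; the ranking score is the last token's logit from the deepest requested
    # layer (all_scores[-1]), so passing several cutoff_layers only keeps the last one.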

@torch.no_grad()
def rank(
self,
query: str,
docs: Union[str, List[str], Document, List[Document]],
doc_ids: Optional[Union[List[str], List[int]]] = None,
metadata: Optional[List[dict]] = None,
batch_size: Optional[int] = None,
max_sequence_length: Optional[int] = None,
) -> RankedResults:
docs = prep_docs(docs, doc_ids, metadata)
pairs = [(query, doc.text) for doc in docs]

# Override self.batch_size if explicitly set
if batch_size is None:
batch_size = self.batch_size

# Same for max_sequence_length
if max_sequence_length is None:
max_sequence_length = self.max_sequence_length

batched_pairs = [
pairs[i : i + batch_size] for i in range(0, len(pairs), batch_size)
]
scores = []

for batch in batched_pairs:
inputs = self._get_inputs(batch, max_sequence_length=max_sequence_length)
inputs = {k: v.to(self.device) for k, v in inputs.items()}

outputs = self.model(**inputs, return_dict=True, **self.params)
all_scores = [
scores[:, -1]
.view(
-1,
)
.float()
for scores in outputs[0]
]
batch_scores = all_scores[-1].cpu().numpy().tolist()

scores.extend(batch_scores)

ranked_results = [
Result(document=doc, score=score, rank=idx + 1)
for idx, (doc, score) in enumerate(
sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
)
]
return RankedResults(results=ranked_results, query=query, has_scores=True)

@torch.no_grad()
def score(self, query: str, doc: str) -> float:
inputs = self._get_inputs(
[(query, doc)], max_sequence_length=self.max_sequence_length
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}

outputs = self.model(**inputs, return_dict=True, **self.params)
all_scores = [
scores[:, -1]
.view(
-1,
)
.float()
for scores in outputs[0]
]
score = all_scores[-1].item()

return score
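A standalone sketch of driving the class above directly, using only the constructor arguments and methods defined in this file (model weights are fetched from the Hugging Face Hub; `trust_remote_code=True` is set internally by the constructor):

```python
from rerankers.models.llm_layerwise_ranker import LLMLayerWiseRanker

ranker = LLMLayerWiseRanker(
    "BAAI/bge-reranker-v2-minicpm-layerwise",  # non-default model; picks up cutoff_layers=[28]
    max_sequence_length=512,
    batch_size=8,
)

# Single query/passage relevance score (higher means more relevant).
print(ranker.score("what causes rain?", "Rain forms when water vapour condenses into droplets."))

# Batched reranking; results come back sorted by descending score, and the per-call
# max_sequence_length override is part of the rank() signature above.
results = ranker.rank(
    query="what causes rain?",
    docs=[
        "Rain forms when water vapour condenses into droplets heavy enough to fall.",
        "The 1998 FIFA World Cup was hosted by France.",
    ],
    max_sequence_length=1024,
)
for result in results:
    print(f"{result.rank}. ({result.score:.3f}) {result.document.text}")
```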
10 changes: 8 additions & 2 deletions rerankers/reranker.py
@@ -30,7 +30,11 @@
"es": "AdrienB134/ColBERTv2.0-spanish-mmarcoES",
},
"flashrank": {"en": "ms-marco-MiniLM-L-12-v2", "other": "ms-marco-MultiBERT-L-12"},
"text-embeddings-inference": {"other": "BAAI/bge-reranker-base"}
"text-embeddings-inference": {"other": "BAAI/bge-reranker-base"},
"llm-layerwise": {
"en": "BAAI/bge-reranker-v2.5-gemma2-lightweight",
"other": "BAAI/bge-reranker-v2.5-gemma2-lightweight",
},
}

DEPS_MAPPING = {
@@ -42,6 +46,7 @@
"ColBERTRanker": "transformers",
"FlashRankRanker": "flashrank",
"RankLLMRanker": "rankllm",
"LLMLayerWiseRanker": "transformers",
}

PROVIDERS = ["cohere", "jina", "voyage", "mixedbread.ai", "text-embeddings-inference"]
@@ -78,6 +83,7 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
"cross-encoder": "TransformerRanker",
"flashrank": "FlashRankRanker",
"rankllm": "RankLLMRanker",
"llm-layerwise": "LLMLayerWiseRanker",
}
return model_mapping.get(explicit_model_type, explicit_model_type)
else:
@@ -89,7 +95,6 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
"rankllm": "RankLLMRanker",
"rankgpt": "RankGPTRanker",
"gpt": "RankGPTRanker",
"zephyr": "RankZephyr",
"colbert": "ColBERTRanker",
"cohere": "APIRanker",
"jina": "APIRanker",
@@ -99,6 +104,7 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
"ms-marco-multibert-l-12": "FlashRankRanker",
"vicuna": "RankLLMRanker",
"zephyr": "RankLLMRanker",
"bge-reranker-v2.5-gemma2-lightweight": "LLMLayerWiseRanker",
}
for key, value in model_mapping.items():
if key in model_name:
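Based on the two mapping tables in this file, a hedged sketch of the loading paths that should now resolve to `LLMLayerWiseRanker` (the `Reranker` factory's full signature is not part of this diff, so the keyword arguments below are assumptions):

```python
from rerankers import Reranker

# Explicit type: "llm-layerwise" is mapped to LLMLayerWiseRanker, and the defaults table
# points it at BAAI/bge-reranker-v2.5-gemma2-lightweight for both "en" and "other".
ranker_a = Reranker("llm-layerwise", lang="en")

# Name-based inference: any model name containing "bge-reranker-v2.5-gemma2-lightweight"
# is routed to LLMLayerWiseRanker without an explicit model_type.
ranker_b = Reranker("BAAI/bge-reranker-v2.5-gemma2-lightweight")

# Other BGE LLM rerankers likely still need the explicit type, since only the lightweight
# name appears in the substring map shown in these hunks.
ranker_c = Reranker("BAAI/bge-reranker-v2-gemma", model_type="llm-layerwise")
```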
