Skip to content

Commit

Permalink
contribution: infinity-integration
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelfeil committed Apr 2, 2024
1 parent e941823 commit bdfcf56
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 24 deletions.
7 changes: 7 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ format: ## Running code formatter: black and isort
@find src -name "*.pyi" ! -name "*_pb2*" -exec black --pyi --config pyproject.toml {} \;
@echo "(ruff) Running fix only..."
@ruff check src docs tests --fix-only
format-check: ## Checking import order, formatting and lint without modifying files
	@echo "(isort) Checking import order..."
	@isort --check .
	@echo "(black) Checking code formatting..."
	@black --config pyproject.toml --check src tests docs
	@echo "(ruff) Linting development project..."
	@# No --fix-only here: that flag rewrites files and suppresses reporting,
	@# which would make this check target mutate the tree and always pass.
	@ruff check src docs tests
lint: ## Running lint checker: ruff
@echo "(ruff) Linting development project..."
@ruff check src docs tests
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dynamic = ["version", "readme"]
[project.optional-dependencies]
all = [
"sentence-transformers",
"infinity_emb[all]",
]

[tool.setuptools]
Expand Down
1 change: 1 addition & 0 deletions requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pytest
pytest-xdist[psutil]
pytest-asyncio
llama_index
pytest-asyncio
140 changes: 116 additions & 24 deletions src/ragas/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
import typing as t
from abc import ABC
from abc import ABC, abstractmethod
from dataclasses import field
from typing import List

Expand All @@ -15,6 +15,9 @@

DEFAULT_MODEL_NAME = "BAAI/bge-small-en-v1.5"

if t.TYPE_CHECKING:
from torch import Tensor


class BaseRagasEmbeddings(Embeddings, ABC):
run_config: RunConfig
Expand All @@ -26,7 +29,7 @@ async def embed_text(self, text: str, is_async=True) -> List[float]:
async def embed_texts(
self, texts: List[str], is_async: bool = True
) -> t.List[t.List[float]]:
if is_async:
if is_async and hasattr(self, "aembed_documents"):
aembed_documents_with_retry = add_async_retry(
self.aembed_documents, self.run_config
)
Expand All @@ -41,6 +44,9 @@ async def embed_texts(
def set_run_config(self, run_config: RunConfig):
self.run_config = run_config

@abstractmethod
def embed_documents(self, texts: List[str]) -> List[List[float]]: ...


class LangchainEmbeddingsWrapper(BaseRagasEmbeddings):
def __init__(
Expand All @@ -60,7 +66,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
async def aembed_query(self, text: str) -> List[float]:
return await self.embeddings.aembed_query(text)

async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
async def aembed_documents(self, texts: List[str]):
return await self.embeddings.aembed_documents(texts)

def set_run_config(self, run_config: RunConfig):
Expand Down Expand Up @@ -110,47 +116,133 @@ def __post_init__(self):
)

if self.is_cross_encoder:
self.model = sentence_transformers.CrossEncoder(
self._ce = sentence_transformers.CrossEncoder(
self.model_name, **self.model_kwargs
)
self.model = self._ce
self.is_cross_encoder = True
else:
self.model = sentence_transformers.SentenceTransformer(
self._st = sentence_transformers.SentenceTransformer(
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
)
self.model = self._st
self.is_cross_encoder = False

# ensure outputs are tensors
if "convert_to_tensor" not in self.encode_kwargs:
self.encode_kwargs["convert_to_tensor"] = True
self.encode_kwargs["convert_to_tensor"] = True

def embed_query(self, text: str) -> List[float]:
return self.embed_documents([text])[0]

def embed_documents(self, texts: List[str]) -> List[List[float]]:
from sentence_transformers.SentenceTransformer import SentenceTransformer
from torch import Tensor

assert isinstance(
self.model, SentenceTransformer
), "Model is not of the type Bi-encoder"
embeddings = self.model.encode(
assert not self.is_cross_encoder, "Model is not of the type Bi-encoder"
embeddings: Tensor = self._st.encode( # type: ignore
texts, normalize_embeddings=True, **self.encode_kwargs
)

assert isinstance(embeddings, Tensor)
return embeddings.tolist()

def predict(self, texts: List[List[str]]) -> List[List[float]]:
from sentence_transformers.cross_encoder import CrossEncoder
from torch import Tensor
assert self.is_cross_encoder, "Model is not of the type CrossEncoder"
predictions: Tensor = self.model.predict(texts, **self.encode_kwargs) # type: ignore
return predictions.tolist()


@dataclass
class InfinityEmbeddings(BaseRagasEmbeddings):
"""Infinity embeddings using infinity_emb package.
usage:
```python
embedding_engine = InfinityEmbeddings(model_name="BAAI/bge-small-en-v1.5")
async with embedding_engine:
embeddings = await embedding_engine.aembed_documents(
["Paris is in France", "The capital of France is Paris", "Infinity batches embeddings on the fly"]
)
assert isinstance(
self.model, CrossEncoder
), "Model is not of the type CrossEncoder"
reranking_engine = InfinityEmbeddings(model_name="BAAI/bge-reranker-base")
async with reranking_engine:
rankings = await reranking_engine.arerank("Where is Paris?", ["Paris is in France", "I don't know the capital of Paris.", "Dummy sentence"])
```
"""

predictions = self.model.predict(texts, **self.encode_kwargs)
model_name: str = DEFAULT_MODEL_NAME
"""Model name to use."""
infinity_engine_kwargs: t.Dict[str, t.Any] = field(default_factory=dict)
"""infinity engine keyword arguments.
{
batch_size: int = 64
revision: str | None = None,
trust_remote_code: bool = True,
engine: str = torch | optimum | ctranslate2
model_warmup: bool = False
vector_disk_cache_path: str = ""
device: Device | str = "auto"
lengths_via_tokenize: bool = False
}
"""

assert isinstance(predictions, Tensor)
return predictions.tolist()
def __post_init__(self):
    """Build the infinity_emb async engine for ``self.model_name``.

    ``infinity_emb`` is imported here rather than at module import time, so
    the package is only required when this class is instantiated. The engine
    is only constructed by this method; it is not started here.

    Raises:
        ImportError: if the ``infinity_emb`` package cannot be imported.
    """
    try:
        import infinity_emb
    except ImportError as exc:
        # The requirement is quoted so the suggested command can be pasted
        # into a shell as-is: unquoted, ``>=0.0.32`` is parsed as an output
        # redirection and the brackets may be glob-expanded.
        raise ImportError(
            "Could not import infinity_emb python package. "
            'Please install it with `pip install "infinity-emb[torch,optimum]>=0.0.32"`.'
        ) from exc
    self.engine = infinity_emb.AsyncEmbeddingEngine(
        model_name_or_path=self.model_name, **self.infinity_engine_kwargs
    )

def embed_documents(self, texts: List[str]) -> List[List[float]]:
    """Synchronous embedding entry point; unsupported for this backend.

    Raises:
        NotImplementedError: on every call.
    """
    message = "Infinity embeddings does not support sync embeddings"
    raise NotImplementedError(message)

def embed_query(self, text: str) -> List[float]:
    """Embed a single text by delegating to ``embed_documents``.

    The text is wrapped in a one-element list and the first row of the
    result is returned; any exception from ``embed_documents`` propagates.
    """
    batch = [text]
    rows = self.embed_documents(batch)
    return rows[0]

async def aembed_documents(self, texts: List[str]) -> t.List[t.List[float]]:
    """Embed ``texts`` with the async engine and return nested Python lists.

    ``self.__aenter__()`` is awaited first, so the call also works outside an
    ``async with`` block.

    Raises:
        ValueError: if ``"embed"`` is not among the engine's capabilities.
    """
    await self.__aenter__()

    supports_embed = "embed" in self.engine.capabilities
    if not supports_embed:
        raise ValueError(
            f"Model={self.model_name} does not have `embed` capability, but only {self.engine.capabilities}. "
            "Try a different model, e.g. `model_name=BAAI/bge-small-en-v1.5`"
        )

    # The engine returns a pair; only the first element (the embeddings) is used.
    vectors, _ = await self.engine.embed(sentences=texts)
    as_array = np.array(vectors)
    return as_array.tolist()

async def aembed_query(self, text: str) -> t.List[float]:
    """Embed a single query by delegating to ``aembed_documents``.

    The query is sent as a one-element batch and the first row is returned.
    """
    return (await self.aembed_documents([text]))[0]

async def arerank(self, query: str, docs: List[str]) -> List[float]:
    """Score ``docs`` against ``query`` using the engine's reranker.

    ``self.__aenter__()`` is awaited first, so the call also works outside an
    ``async with`` block. The first element of the engine's result pair is
    returned unchanged.

    Raises:
        ValueError: if ``"rerank"`` is not among the engine's capabilities.
    """
    await self.__aenter__()

    can_rerank = "rerank" in self.engine.capabilities
    if not can_rerank:
        raise ValueError(
            f"Model={self.model_name} does not have `rerank` capability, but only {self.engine.capabilities}. "
            "Try a different model, e.g. `model_name=mixedbread-ai/mxbai-rerank-base-v1`"
        )

    scores, _ = await self.engine.rerank(query=query, docs=docs)
    return scores

async def __aenter__(self, *args, **kwargs):
    """Start the engine if it is not already running and return ``self``.

    Returning ``self`` makes ``async with InfinityEmbeddings(...) as emb:``
    bind ``emb`` to this instance instead of ``None``. Callers that await
    this method directly and ignore the result are unaffected.
    """
    if not self.engine.running:
        await self.engine.astart()
    return self

async def __aexit__(self, *args, **kwargs):
    """Stop the engine on leaving ``async with`` if it is currently running.

    Returns ``None``, so an exception raised inside the block is not
    suppressed.
    """
    if not self.engine.running:
        return
    await self.engine.astop()

def __del__(self, *args, **kwargs):
    """Best-effort synchronous shutdown of a still-running engine.

    A finalizer may run on a partially initialised object: if
    ``__post_init__`` raised before assigning ``self.engine``, the attribute
    does not exist. Every lookup is therefore defensive, and nothing is
    raised deliberately from this method.
    """
    engine = getattr(self, "engine", None)
    if engine is None or not getattr(engine, "running", False):
        return
    stop = getattr(engine, "stop", None)
    # Without a sync ``stop`` there is nothing safe to do from a finalizer;
    # ``astop`` cannot be awaited here.
    if callable(stop):
        stop()


def embedding_factory(run_config: t.Optional[RunConfig] = None) -> BaseRagasEmbeddings:
Expand Down
51 changes: 51 additions & 0 deletions tests/unit/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -1 +1,52 @@
from __future__ import annotations

import numpy as np
import pytest

from ragas.embeddings.base import InfinityEmbeddings

try:
import infinity_emb # noqa
import torch # noqa

INFINITY_AVAILABLE = True
except ImportError:
INFINITY_AVAILABLE = False


@pytest.mark.skipif(not INFINITY_AVAILABLE, reason="infinity_emb is not installed.")
@pytest.mark.asyncio
async def test_basic_embedding():
    """Embed 60 sentences and check output type, shape and relative similarity."""
    engine = InfinityEmbeddings(model_name="BAAI/bge-small-en-v1.5")
    sentences = [
        "Paris is in France",
        "The capital of France is Paris",
        "Infintiy batches embeddings on the fly",
    ] * 20

    async with engine:
        vectors = await engine.aembed_documents(sentences)

    assert isinstance(vectors, list)
    matrix = np.array(vectors)
    assert matrix.shape == (60, 384)
    # The two Paris sentences should be closer to each other than to the
    # unrelated third sentence.
    related = matrix[0] @ matrix[1]
    unrelated = matrix[0] @ matrix[2]
    assert related > unrelated


@pytest.mark.skipif(not INFINITY_AVAILABLE, reason="infinity_emb is not installed.")
@pytest.mark.asyncio
async def test_rerank():
    """Rerank three documents and check the relevant one scores highest."""
    reranker = InfinityEmbeddings(model_name="BAAI/bge-reranker-base")
    question = "Where is Paris?"
    candidates = [
        "Paris is in France",
        "I don't know the capital of Paris.",
        "Dummy sentence",
    ]

    async with reranker:
        scores = await reranker.arerank(question, candidates)

    assert len(scores) == 3
    top_score = scores[0]
    assert top_score > scores[1]
    assert top_score > scores[2]

0 comments on commit bdfcf56

Please sign in to comment.