Skip to content

Commit

Permalink
feat/refactor embedders (#383)
Browse files Browse the repository at this point in the history
* refactor embedders

* fix mixedbreadai embedder

* create default behavior for embed_query

* fix default behavior of _embed_query
  • Loading branch information
rbiseck3 authored Feb 13, 2025
1 parent 17a5dd3 commit 442dbfa
Show file tree
Hide file tree
Showing 13 changed files with 204 additions and 403 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
## 0.5.3-dev0
## 0.5.3-dev1

### Enhancements

* **Optimize embedder code** - Move duplicate code to base interface, exit early if no elements have text.

### Fixes

## 0.5.2
Expand Down
2 changes: 1 addition & 1 deletion unstructured_ingest/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.5.3-dev0" # pragma: no cover
__version__ = "0.5.3-dev1" # pragma: no cover
6 changes: 6 additions & 0 deletions unstructured_ingest/embed/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@ def get_async_client(self) -> "AsyncAzureOpenAI":
class AzureOpenAIEmbeddingEncoder(OpenAIEmbeddingEncoder):
    """OpenAI embedding encoder whose client comes from the Azure OpenAI config."""

    config: AzureOpenAIEmbeddingConfig

    def get_client(self) -> "AzureOpenAI":
        """Return the synchronous client built by the Azure config."""
        azure_config = self.config
        return azure_config.get_client()


@dataclass
class AsyncAzureOpenAIEmbeddingEncoder(AsyncOpenAIEmbeddingEncoder):
    """Async OpenAI embedding encoder whose client comes from the Azure OpenAI config."""

    config: AzureOpenAIEmbeddingConfig

    def get_client(self) -> "AsyncAzureOpenAI":
        """Return the asynchronous client built by the Azure config."""
        azure_config = self.config
        return azure_config.get_async_client()
12 changes: 11 additions & 1 deletion unstructured_ingest/embed/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
)
from unstructured_ingest.logger import logger
from unstructured_ingest.utils.dep_check import requires_dependencies
from unstructured_ingest.v2.errors import ProviderError, RateLimitError, UserAuthError, UserError
from unstructured_ingest.v2.errors import (
ProviderError,
RateLimitError,
UserAuthError,
UserError,
is_internal_error,
)

if TYPE_CHECKING:
from botocore.client import BaseClient
Expand Down Expand Up @@ -54,6 +60,8 @@ class BedrockEmbeddingConfig(EmbeddingConfig):
embed_model_name: str = Field(default="amazon.titan-embed-text-v1", alias="model_name")

def wrap_error(self, e: Exception) -> Exception:
if is_internal_error(e=e):
return e
from botocore.exceptions import ClientError

if isinstance(e, ClientError):
Expand Down Expand Up @@ -148,6 +156,8 @@ def embed_query(self, query: str) -> list[float]:
def embed_documents(self, elements: list[dict]) -> list[dict]:
elements = elements.copy()
elements_with_text = [e for e in elements if e.get("text")]
if not elements_with_text:
return elements
embeddings = [self.embed_query(query=e["text"]) for e in elements_with_text]
for element, embedding in zip(elements_with_text, embeddings):
element[EMBEDDINGS_KEY] = embedding
Expand Down
4 changes: 3 additions & 1 deletion unstructured_ingest/embed/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_encoder_kwargs(self) -> dict:
class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
config: HuggingFaceEmbeddingConfig

def embed_query(self, query: str) -> list[float]:
def _embed_query(self, query: str) -> list[float]:
return self._embed_documents(texts=[query])[0]

def _embed_documents(self, texts: list[str]) -> list[list[float]]:
Expand All @@ -58,6 +58,8 @@ def _embed_documents(self, texts: list[str]) -> list[list[float]]:
def embed_documents(self, elements: list[dict]) -> list[dict]:
elements = elements.copy()
elements_with_text = [e for e in elements if e.get("text")]
if not elements_with_text:
return elements
embeddings = self._embed_documents([e["text"] for e in elements_with_text])
for element, embedding in zip(elements_with_text, embeddings):
element[EMBEDDINGS_KEY] = embedding
Expand Down
84 changes: 61 additions & 23 deletions unstructured_ingest/embed/interfaces.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import asyncio
from abc import ABC, abstractmethod
from abc import ABC
from dataclasses import dataclass
from typing import Optional
from typing import Any, Optional

import numpy as np
from pydantic import BaseModel, Field

from unstructured_ingest.utils.data_prep import batch_generator

EMBEDDINGS_KEY = "embeddings"


Expand Down Expand Up @@ -50,21 +51,37 @@ def is_unit_vector(self) -> bool:
exemplary_embedding = self.get_exemplary_embedding()
return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)

@abstractmethod
def embed_documents(self, elements: list[dict]) -> list[dict]:
pass
def get_client(self):
    """Return the provider client that is handed to ``embed_batch``.

    This base implementation is only a placeholder and always raises.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()

@abstractmethod
def embed_query(self, query: str) -> list[float]:
pass
def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
    """Embed one batch of texts with the given client.

    This base implementation is only a placeholder and always raises.

    Args:
        client: Object returned by ``get_client``.
        batch: Texts to embed in a single request.

    Returns:
        One embedding vector per text in ``batch``.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()

def _embed_documents(self, elements: list[str]) -> list[list[float]]:
results = []
for text in elements:
response = self.embed_query(query=text)
results.append(response)
def embed_documents(self, elements: list[dict]) -> list[dict]:
    """Attach an embedding to every element that carries text.

    The texts are passed to ``embed_batch`` in batches of
    ``self.config.batch_size`` (or as a single batch when no batch size is
    configured), and each resulting vector is stored on its element under
    ``EMBEDDINGS_KEY``.

    Args:
        elements: Element dicts; only those with a truthy ``"text"`` value
            are embedded.

    Returns:
        A shallow copy of ``elements``. The element dicts themselves are
        updated in place.

    Raises:
        Exception: Whatever ``self.wrap_error`` returns for an error raised
            while embedding a batch.
    """
    elements = elements.copy()
    elements_with_text = [e for e in elements if e.get("text")]
    if not elements_with_text:
        # Nothing to embed: skip client creation and the provider call entirely.
        return elements
    client = self.get_client()
    texts = [e["text"] for e in elements_with_text]
    embeddings: list[list[float]] = []
    try:
        for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
            # Keep the per-batch result under its own name. Rebinding
            # ``embeddings`` here would drop all earlier batches and then
            # extend the list with itself.
            batch_embeddings = self.embed_batch(client=client, batch=batch)
            embeddings.extend(batch_embeddings)
    except Exception as e:
        raise self.wrap_error(e=e)
    for element, embedding in zip(elements_with_text, embeddings):
        element[EMBEDDINGS_KEY] = embedding
    return elements

def _embed_query(self, query: str) -> list[float]:
client = self.get_client()
return self.embed_batch(client=client, batch=[query])[0]

return results
def embed_query(self, query: str) -> list[float]:
    """Embed a single query, translating any failure through ``wrap_error``.

    Args:
        query: Text to embed.

    Returns:
        The embedding produced by ``_embed_query``.

    Raises:
        Exception: Whatever ``self.wrap_error`` returns for an error raised
            by ``_embed_query``.
    """
    try:
        embedding = self._embed_query(query=query)
    except Exception as exc:
        raise self.wrap_error(e=exc)
    return embedding


@dataclass
Expand All @@ -88,14 +105,35 @@ async def is_unit_vector(self) -> bool:
exemplary_embedding = await self.get_exemplary_embedding()
return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)

@abstractmethod
def get_client(self):
    """Return the provider client that is handed to the async ``embed_batch``.

    This base implementation is only a placeholder and always raises.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()

async def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
    """Embed one batch of texts with the given client.

    This base implementation is only a placeholder and always raises when
    awaited.

    Args:
        client: Object returned by ``get_client``.
        batch: Texts to embed in a single request.

    Returns:
        One embedding vector per text in ``batch``.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()

async def embed_documents(self, elements: list[dict]) -> list[dict]:
    """Attach an embedding to every element that carries text.

    The texts are passed to the awaited ``embed_batch`` in batches of
    ``self.config.batch_size`` (or as a single batch when no batch size is
    configured), one batch after another, and each resulting vector is stored
    on its element under ``EMBEDDINGS_KEY``.

    Args:
        elements: Element dicts; only those with a truthy ``"text"`` value
            are embedded.

    Returns:
        A shallow copy of ``elements``. The element dicts themselves are
        updated in place.

    Raises:
        Exception: Whatever ``self.wrap_error`` returns for an error raised
            while embedding a batch.
    """
    elements = elements.copy()
    elements_with_text = [e for e in elements if e.get("text")]
    if not elements_with_text:
        # Nothing to embed: skip client creation and the provider call entirely.
        return elements
    client = self.get_client()
    texts = [e["text"] for e in elements_with_text]
    embeddings: list[list[float]] = []
    try:
        for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
            # Keep the per-batch result under its own name. Rebinding
            # ``embeddings`` here would drop all earlier batches and then
            # extend the list with itself.
            batch_embeddings = await self.embed_batch(client=client, batch=batch)
            embeddings.extend(batch_embeddings)
    except Exception as e:
        raise self.wrap_error(e=e)
    for element, embedding in zip(elements_with_text, embeddings):
        element[EMBEDDINGS_KEY] = embedding
    return elements

async def _embed_query(self, query: str) -> list[float]:
client = self.get_client()
embeddings = await self.embed_batch(client=client, batch=[query])
return embeddings[0]

@abstractmethod
async def embed_query(self, query: str) -> list[float]:
pass

async def _embed_documents(self, elements: list[str]) -> list[list[float]]:
results = await asyncio.gather(*[self.embed_query(query=text) for text in elements])
return results
try:
return await self._embed_query(query=query)
except Exception as e:
raise self.wrap_error(e=e)
142 changes: 28 additions & 114 deletions unstructured_ingest/embed/mixedbreadai.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import asyncio
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING

from pydantic import Field, SecretStr

from unstructured_ingest.embed.interfaces import (
EMBEDDINGS_KEY,
AsyncBaseEmbeddingEncoder,
BaseEmbeddingEncoder,
EmbeddingConfig,
)
from unstructured_ingest.utils.data_prep import batch_generator
from unstructured_ingest.utils.dep_check import requires_dependencies

USER_AGENT = "@mixedbread-ai/unstructured"
Expand Down Expand Up @@ -85,7 +82,7 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):

def get_exemplary_embedding(self) -> list[float]:
"""Get an exemplary embedding to determine dimensions and unit vector status."""
return self._embed(["Q"])[0]
return self.embed_query(query="Q")

@requires_dependencies(
["mixedbread_ai"],
Expand All @@ -100,59 +97,19 @@ def get_request_options(self) -> "RequestOptions":
additional_headers={"User-Agent": USER_AGENT},
)

def _embed(self, texts: list[str]) -> list[list[float]]:
"""
Embed a list of texts using the Mixedbread AI API.
Args:
texts (list[str]): List of texts to embed.
Returns:
list[list[float]]: List of embeddings.
"""

responses = []
client = self.config.get_client()
for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
response = client.embeddings(
model=self.config.embedder_model_name,
normalized=True,
encoding_format=ENCODING_FORMAT,
truncation_strategy=TRUNCATION_STRATEGY,
request_options=self.get_request_options(),
input=batch,
)
responses.append(response)
return [item.embedding for response in responses for item in response.data]

def embed_documents(self, elements: list[dict]) -> list[dict]:
"""
Embed a list of document elements.
Args:
elements (list[Element]): List of document elements.
Returns:
list[Element]: Elements with embeddings.
"""
elements = elements.copy()
elements_with_text = [e for e in elements if e.get("text")]
embeddings = self._embed([e["text"] for e in elements_with_text])
for element, embedding in zip(elements_with_text, embeddings):
element[EMBEDDINGS_KEY] = embedding
return elements

def embed_query(self, query: str) -> list[float]:
"""
Embed a query string.
Args:
query (str): Query string to embed.
Returns:
list[float]: Embedding of the query.
"""
return self._embed([query])[0]
def get_client(self) -> "MixedbreadAI":
    """Return the synchronous Mixedbread AI client built by the encoder config."""
    encoder_config = self.config
    return encoder_config.get_client()

def embed_batch(self, client: "MixedbreadAI", batch: list[str]) -> list[list[float]]:
    """Embed one batch of texts through the Mixedbread AI embeddings endpoint.

    The request uses the configured model name, ``normalized=True``, the
    module-level encoding format and truncation strategy, and this encoder's
    request options.

    Args:
        client: Mixedbread AI client used to issue the request.
        batch: Texts to embed in a single request.

    Returns:
        The ``embedding`` of every item in the response's ``data``.
    """
    model_name = self.config.embedder_model_name
    options = self.get_request_options()
    result = client.embeddings(
        model=model_name,
        normalized=True,
        encoding_format=ENCODING_FORMAT,
        truncation_strategy=TRUNCATION_STRATEGY,
        request_options=options,
        input=batch,
    )
    return [item.embedding for item in result.data]


@dataclass
Expand All @@ -162,8 +119,7 @@ class AsyncMixedbreadAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):

async def get_exemplary_embedding(self) -> list[float]:
"""Get an exemplary embedding to determine dimensions and unit vector status."""
embedding = await self._embed(["Q"])
return embedding[0]
return await self.embed_query(query="Q")

@requires_dependencies(
["mixedbread_ai"],
Expand All @@ -178,58 +134,16 @@ def get_request_options(self) -> "RequestOptions":
additional_headers={"User-Agent": USER_AGENT},
)

async def _embed(self, texts: list[str]) -> list[list[float]]:
"""
Embed a list of texts using the Mixedbread AI API.
Args:
texts (list[str]): List of texts to embed.
Returns:
list[list[float]]: List of embeddings.
"""
client = self.config.get_async_client()
tasks = []
for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
tasks.append(
client.embeddings(
model=self.config.embedder_model_name,
normalized=True,
encoding_format=ENCODING_FORMAT,
truncation_strategy=TRUNCATION_STRATEGY,
request_options=self.get_request_options(),
input=batch,
)
)
responses = await asyncio.gather(*tasks)
return [item.embedding for response in responses for item in response.data]

async def embed_documents(self, elements: list[dict]) -> list[dict]:
"""
Embed a list of document elements.
Args:
elements (list[Element]): List of document elements.
Returns:
list[Element]: Elements with embeddings.
"""
elements = elements.copy()
elements_with_text = [e for e in elements if e.get("text")]
embeddings = await self._embed([e["text"] for e in elements_with_text])
for element, embedding in zip(elements_with_text, embeddings):
element[EMBEDDINGS_KEY] = embedding
return elements

async def embed_query(self, query: str) -> list[float]:
"""
Embed a query string.
Args:
query (str): Query string to embed.
Returns:
list[float]: Embedding of the query.
"""
embedding = await self._embed([query])
return embedding[0]
def get_client(self) -> "AsyncMixedbreadAI":
    """Return the asynchronous Mixedbread AI client built by the encoder config."""
    encoder_config = self.config
    return encoder_config.get_async_client()

async def embed_batch(self, client: "AsyncMixedbreadAI", batch: list[str]) -> list[list[float]]:
    """Embed one batch of texts through the async Mixedbread AI embeddings endpoint.

    The request uses the configured model name, ``normalized=True``, the
    module-level encoding format and truncation strategy, and this encoder's
    request options.

    Args:
        client: Async Mixedbread AI client used to issue the request.
        batch: Texts to embed in a single request.

    Returns:
        The ``embedding`` of every item in the awaited response's ``data``.
    """
    model_name = self.config.embedder_model_name
    options = self.get_request_options()
    result = await client.embeddings(
        model=model_name,
        normalized=True,
        encoding_format=ENCODING_FORMAT,
        truncation_strategy=TRUNCATION_STRATEGY,
        request_options=options,
        input=batch,
    )
    return [item.embedding for item in result.data]
Loading

0 comments on commit 442dbfa

Please sign in to comment.