Skip to content

Commit

Permalink
bugfix/only embed content with text (#380)
Browse files Browse the repository at this point in the history
* only embed content with text

* make sure to use consistent key for embeddings metadata
  • Loading branch information
rbiseck3 authored Feb 11, 2025
1 parent 399854f commit 56e69a4
Show file tree
Hide file tree
Showing 11 changed files with 111 additions and 62 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
## 0.5.2-dev0

### Enhancements

* **Only embed elements with text** - Only embed elements with text to avoid errors from embedders and optimize calls to APIs.

## 0.5.1

### Fixes
Expand Down
2 changes: 1 addition & 1 deletion unstructured_ingest/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.5.1" # pragma: no cover
__version__ = "0.5.2-dev0" # pragma: no cover
19 changes: 13 additions & 6 deletions unstructured_ingest/embed/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pydantic import Field, SecretStr

from unstructured_ingest.embed.interfaces import (
EMBEDDINGS_KEY,
AsyncBaseEmbeddingEncoder,
BaseEmbeddingEncoder,
EmbeddingConfig,
Expand Down Expand Up @@ -145,9 +146,12 @@ def embed_query(self, query: str) -> list[float]:
return response_body.get("embedding")

def embed_documents(self, elements: list[dict]) -> list[dict]:
embeddings = [self.embed_query(query=e.get("text", "")) for e in elements]
elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
return elements_with_embeddings
elements = elements.copy()
elements_with_text = [e for e in elements if e.get("text")]
embeddings = [self.embed_query(query=e["text"]) for e in elements_with_text]
for element, embedding in zip(elements_with_text, embeddings):
element[EMBEDDINGS_KEY] = embedding
return elements


@dataclass
Expand Down Expand Up @@ -186,8 +190,11 @@ async def embed_query(self, query: str) -> list[float]:
raise ValueError(f"Error raised by inference endpoint: {e}")

async def embed_documents(self, elements: list[dict]) -> list[dict]:
elements = elements.copy()
elements_with_text = [e for e in elements if e.get("text")]
embeddings = await asyncio.gather(
*[self.embed_query(query=e.get("text", "")) for e in elements]
*[self.embed_query(query=e.get("text", "")) for e in elements_with_text]
)
elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
return elements_with_embeddings
for element, embedding in zip(elements_with_text, embeddings):
element[EMBEDDINGS_KEY] = embedding
return elements
15 changes: 11 additions & 4 deletions unstructured_ingest/embed/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@

from pydantic import Field

from unstructured_ingest.embed.interfaces import BaseEmbeddingEncoder, EmbeddingConfig
from unstructured_ingest.embed.interfaces import (
EMBEDDINGS_KEY,
BaseEmbeddingEncoder,
EmbeddingConfig,
)
from unstructured_ingest.utils.dep_check import requires_dependencies

if TYPE_CHECKING:
Expand Down Expand Up @@ -52,6 +56,9 @@ def _embed_documents(self, texts: list[str]) -> list[list[float]]:
return embeddings.tolist()

def embed_documents(self, elements: list[dict]) -> list[dict]:
embeddings = self._embed_documents([e.get("text", "") for e in elements])
elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
return elements_with_embeddings
elements = elements.copy()
elements_with_text = [e for e in elements if e.get("text")]
embeddings = self._embed_documents([e["text"] for e in elements_with_text])
for element, embedding in zip(elements_with_text, embeddings):
element[EMBEDDINGS_KEY] = embedding
return elements
23 changes: 2 additions & 21 deletions unstructured_ingest/embed/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np
from pydantic import BaseModel, Field

EMBEDDINGS_KEY = "embeddings"


class EmbeddingConfig(BaseModel):
batch_size: Optional[int] = Field(
Expand All @@ -26,27 +28,6 @@ def wrap_error(self, e: Exception) -> Exception:
if possible"""
return e

@staticmethod
def _add_embeddings_to_elements(
elements: list[dict], embeddings: list[list[float]]
) -> list[dict]:
"""
Add embeddings to elements.
Args:
elements (list[Element]): List of elements.
embeddings (list[list[float]]): List of embeddings.
Returns:
list[Element]: Elements with embeddings added.
"""
assert len(elements) == len(embeddings)
elements_w_embedding = []
for i, element in enumerate(elements):
element["embeddings"] = embeddings[i]
elements_w_embedding.append(element)
return elements


@dataclass
class BaseEmbeddingEncoder(BaseEncoder, ABC):
Expand Down
17 changes: 13 additions & 4 deletions unstructured_ingest/embed/mixedbreadai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pydantic import Field, SecretStr

from unstructured_ingest.embed.interfaces import (
EMBEDDINGS_KEY,
AsyncBaseEmbeddingEncoder,
BaseEmbeddingEncoder,
EmbeddingConfig,
Expand Down Expand Up @@ -134,8 +135,12 @@ def embed_documents(self, elements: list[dict]) -> list[dict]:
Returns:
list[Element]: Elements with embeddings.
"""
embeddings = self._embed([e.get("text", "") for e in elements])
return self._add_embeddings_to_elements(elements, embeddings)
elements = elements.copy()
elements_with_text = [e for e in elements if e.get("text")]
embeddings = self._embed([e["text"] for e in elements_with_text])
for element, embedding in zip(elements_with_text, embeddings):
element[EMBEDDINGS_KEY] = embedding
return elements

def embed_query(self, query: str) -> list[float]:
"""
Expand Down Expand Up @@ -209,8 +214,12 @@ async def embed_documents(self, elements: list[dict]) -> list[dict]:
Returns:
list[Element]: Elements with embeddings.
"""
embeddings = await self._embed([e.get("text", "") for e in elements])
return self._add_embeddings_to_elements(elements, embeddings)
elements = elements.copy()
elements_with_text = [e for e in elements if e.get("text")]
embeddings = await self._embed([e["text"] for e in elements_with_text])
for element, embedding in zip(elements_with_text, embeddings):
element[EMBEDDINGS_KEY] = embedding
return elements

async def embed_query(self, query: str) -> list[float]:
"""
Expand Down
19 changes: 13 additions & 6 deletions unstructured_ingest/embed/octoai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pydantic import Field, SecretStr

from unstructured_ingest.embed.interfaces import (
EMBEDDINGS_KEY,
AsyncBaseEmbeddingEncoder,
BaseEmbeddingEncoder,
EmbeddingConfig,
Expand Down Expand Up @@ -89,7 +90,9 @@ def embed_query(self, query: str):
return response.data[0].embedding

def embed_documents(self, elements: list[dict]) -> list[dict]:
texts = [e.get("text", "") for e in elements]
elements = elements.copy()
elements_with_text = [e for e in elements if e.get("text")]
texts = [e["text"] for e in elements_with_text]
embeddings = []
client = self.config.get_client()
try:
Expand All @@ -100,8 +103,9 @@ def embed_documents(self, elements: list[dict]) -> list[dict]:
embeddings.extend([data.embedding for data in response.data])
except Exception as e:
raise self.wrap_error(e=e)
elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
return elements_with_embeddings
for element, embedding in zip(elements_with_text, embeddings):
element[EMBEDDINGS_KEY] = embedding
return elements


@dataclass
Expand All @@ -122,7 +126,9 @@ async def embed_query(self, query: str):
return response.data[0].embedding

async def embed_documents(self, elements: list[dict]) -> list[dict]:
texts = [e.get("text", "") for e in elements]
elements = elements.copy()
elements_with_text = [e for e in elements if e.get("text")]
texts = [e["text"] for e in elements_with_text]
client = self.config.get_async_client()
embeddings = []
try:
Expand All @@ -133,5 +139,6 @@ async def embed_documents(self, elements: list[dict]) -> list[dict]:
embeddings.extend([data.embedding for data in response.data])
except Exception as e:
raise self.wrap_error(e=e)
elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
return elements_with_embeddings
for element, embedding in zip(elements_with_text, embeddings):
element[EMBEDDINGS_KEY] = embedding
return elements
19 changes: 13 additions & 6 deletions unstructured_ingest/embed/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pydantic import Field, SecretStr

from unstructured_ingest.embed.interfaces import (
EMBEDDINGS_KEY,
AsyncBaseEmbeddingEncoder,
BaseEmbeddingEncoder,
EmbeddingConfig,
Expand Down Expand Up @@ -82,7 +83,9 @@ def embed_query(self, query: str) -> list[float]:

def embed_documents(self, elements: list[dict]) -> list[dict]:
client = self.config.get_client()
texts = [e.get("text", "") for e in elements]
elements = elements.copy()
elements_with_text = [e for e in elements if e.get("text")]
texts = [e["text"] for e in elements_with_text]
embeddings = []
try:
for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
Expand All @@ -92,8 +95,9 @@ def embed_documents(self, elements: list[dict]) -> list[dict]:
embeddings.extend([data.embedding for data in response.data])
except Exception as e:
raise self.wrap_error(e=e)
elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
return elements_with_embeddings
for element, embedding in zip(elements_with_text, embeddings):
element[EMBEDDINGS_KEY] = embedding
return elements


@dataclass
Expand All @@ -115,7 +119,9 @@ async def embed_query(self, query: str) -> list[float]:

async def embed_documents(self, elements: list[dict]) -> list[dict]:
client = self.config.get_async_client()
texts = [e.get("text", "") for e in elements]
elements = elements.copy()
elements_with_text = [e for e in elements if e.get("text")]
texts = [e["text"] for e in elements_with_text]
embeddings = []
try:
for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
Expand All @@ -125,5 +131,6 @@ async def embed_documents(self, elements: list[dict]) -> list[dict]:
embeddings.extend([data.embedding for data in response.data])
except Exception as e:
raise self.wrap_error(e=e)
elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
return elements_with_embeddings
for element, embedding in zip(elements_with_text, embeddings):
element[EMBEDDINGS_KEY] = embedding
return elements
17 changes: 13 additions & 4 deletions unstructured_ingest/embed/togetherai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pydantic import Field, SecretStr

from unstructured_ingest.embed.interfaces import (
EMBEDDINGS_KEY,
AsyncBaseEmbeddingEncoder,
BaseEmbeddingEncoder,
EmbeddingConfig,
Expand Down Expand Up @@ -67,8 +68,12 @@ def embed_query(self, query: str) -> list[float]:
return self._embed_documents(elements=[query])[0]

def embed_documents(self, elements: list[dict]) -> list[dict]:
embeddings = self._embed_documents([e.get("text", "") for e in elements])
return self._add_embeddings_to_elements(elements, embeddings)
elements = elements.copy()
elements_with_text = [e for e in elements if e.get("text")]
embeddings = self._embed_documents([e["text"] for e in elements_with_text])
for element, embedding in zip(elements_with_text, embeddings):
element[EMBEDDINGS_KEY] = embedding
return elements

def _embed_documents(self, elements: list[str]) -> list[list[float]]:
client = self.config.get_client()
Expand Down Expand Up @@ -98,8 +103,12 @@ async def embed_query(self, query: str) -> list[float]:
return embedding[0]

async def embed_documents(self, elements: list[dict]) -> list[dict]:
embeddings = await self._embed_documents([e.get("text", "") for e in elements])
return self._add_embeddings_to_elements(elements, embeddings)
elements = elements.copy()
elements_with_text = [e for e in elements if e.get("text")]
embeddings = await self._embed_documents([e["text"] for e in elements_with_text])
for element, embedding in zip(elements_with_text, embeddings):
element[EMBEDDINGS_KEY] = embedding
return elements

async def _embed_documents(self, elements: list[str]) -> list[list[float]]:
client = self.config.get_async_client()
Expand Down
19 changes: 13 additions & 6 deletions unstructured_ingest/embed/vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pydantic.functional_validators import BeforeValidator

from unstructured_ingest.embed.interfaces import (
EMBEDDINGS_KEY,
AsyncBaseEmbeddingEncoder,
BaseEmbeddingEncoder,
EmbeddingConfig,
Expand Down Expand Up @@ -75,9 +76,12 @@ def embed_query(self, query):
return self._embed_documents(elements=[query])[0]

def embed_documents(self, elements: list[dict]) -> list[dict]:
embeddings = self._embed_documents([e.get("text", "") for e in elements])
elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
return elements_with_embeddings
elements = elements.copy()
elements_with_text = [e for e in elements if e.get("text")]
embeddings = self._embed_documents([e["text"] for e in elements_with_text])
for element, embedding in zip(elements_with_text, embeddings):
element[EMBEDDINGS_KEY] = embedding
return elements

@requires_dependencies(
["vertexai"],
Expand Down Expand Up @@ -110,9 +114,12 @@ async def embed_query(self, query):
return embedding[0]

async def embed_documents(self, elements: list[dict]) -> list[dict]:
embeddings = await self._embed_documents([e.get("text", "") for e in elements])
elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
return elements_with_embeddings
elements = elements.copy()
elements_with_text = [e for e in elements if e.get("text")]
embeddings = await self._embed_documents([e["text"] for e in elements_with_text])
for element, embedding in zip(elements_with_text, embeddings):
element[EMBEDDINGS_KEY] = embedding
return elements

@requires_dependencies(
["vertexai"],
Expand Down
17 changes: 13 additions & 4 deletions unstructured_ingest/embed/voyageai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pydantic import Field, SecretStr

from unstructured_ingest.embed.interfaces import (
EMBEDDINGS_KEY,
AsyncBaseEmbeddingEncoder,
BaseEmbeddingEncoder,
EmbeddingConfig,
Expand Down Expand Up @@ -107,8 +108,12 @@ def _embed_documents(self, elements: list[str]) -> list[list[float]]:
return embeddings

def embed_documents(self, elements: list[dict]) -> list[dict]:
embeddings = self._embed_documents([e.get("text", "") for e in elements])
return self._add_embeddings_to_elements(elements, embeddings)
elements = elements.copy()
elements_with_text = [e for e in elements if e.get("text")]
embeddings = self._embed_documents([e["text"] for e in elements_with_text])
for element, embedding in zip(elements_with_text, embeddings):
element[EMBEDDINGS_KEY] = embedding
return elements

def embed_query(self, query: str) -> list[float]:
return self._embed_documents(elements=[query])[0]
Expand All @@ -135,8 +140,12 @@ async def _embed_documents(self, elements: list[str]) -> list[list[float]]:
return embeddings

async def embed_documents(self, elements: list[dict]) -> list[dict]:
embeddings = await self._embed_documents([e.get("text", "") for e in elements])
return self._add_embeddings_to_elements(elements, embeddings)
elements = elements.copy()
elements_with_text = [e for e in elements if e.get("text")]
embeddings = await self._embed_documents([e["text"] for e in elements_with_text])
for element, embedding in zip(elements_with_text, embeddings):
element[EMBEDDINGS_KEY] = embedding
return elements

async def embed_query(self, query: str) -> list[float]:
embedding = await self._embed_documents(elements=[query])
Expand Down

0 comments on commit 56e69a4

Please sign in to comment.