Skip to content

Commit

Permalink
Fix embed endpoint (#38)
Browse files Browse the repository at this point in the history
* Remove unsupported embedding models

* Add methods to make requests using items instead of texts

* Fix embed method and clarify variable names

* Add TODO and fix tests

* Add tests for embed endpoints

* Function name and docstring changes and add logging for mismatch lists

* Remove TODO

* Add guarding for embed_and_join_metadata_by_columns

* Add object-based embedding method

* Add some notes

* Clean up docstrings

* Fix bugs and add tests
  • Loading branch information
FullMetalMeowchemist authored Sep 18, 2023
1 parent 34beb92 commit 2e9cf0a
Show file tree
Hide file tree
Showing 3 changed files with 321 additions and 12 deletions.
1 change: 1 addition & 0 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies = [
"requests~=2.31",
"validators~=0.21",
"pandas~=2.0",
"more-itertools~=10.1",
]

[project.optional-dependencies]
Expand Down
130 changes: 120 additions & 10 deletions python/starpoint/embedding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import more_itertools
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Hashable, List, Optional
from uuid import UUID

import requests
Expand All @@ -22,11 +23,15 @@

# Error and warning messages
SSL_ERROR_MSG = "Request failed due to SSLError. Error is likely due to invalid API key. Please check if your API is correct and still valid."
TEXT_METADATA_LENGTH_MISMATCH_WARNING = (
"The length of the texts and metadatas provided are different. There may be a mismatch "
"between texts and the metadatas length; this may cause undesired results between the joining of "
"embeddings and metadatas."
)


class EmbeddingModel(Enum):
MINI6 = "MINI6"
MINI12 = "MINI12"
MINILM = "MiniLm"


class EmbeddingClient(object):
Expand All @@ -41,23 +46,128 @@ def __init__(self, api_key: UUID, host: Optional[str] = None):

def embed(
self,
text: List[str],
texts: List[str],
model: EmbeddingModel,
) -> Dict[str, List[Dict]]:
"""Takes some texts creates embeddings using a model in starpoint. This is a
version of `embed_and_join_metadata_by_column` where joining metadata with the result is
not necessary. The same API is used for the two methods.
Args:
texts: List of strings to create embeddings from.
model: An enum choice from EmbeddingModel.
Returns:
dict: Result with list of texts, metadata, and embeddings.
Raises:
requests.exceptions.SSLError: Failure likely due to network issues.
"""
text_embedding_items = [{"text": text, "metadata": None} for text in texts]
return self.embed_items(text_embedding_items=text_embedding_items, model=model)

def embed_and_join_metadata_by_columns(
self,
texts: List[str],
metadatas: List[Dict],
model: EmbeddingModel,
) -> Dict[str, List[Dict]]:
"""Takes some texts and creates embeddings using a model in starpoint. Prefer using `embed_and_join_metadata` or
`embed_items` instead, as mismatched `texts` and `metadatas` will output undesirable results.
Under the hood this is using `embed_items`.
Args:
texts: List of strings to create embeddings from.
metadatas: List of metadata to join to the string and embedding when the embedding operation is complete.
This metadata makes your embeddings queryable within starpoint.
model: An enum choice from EmbeddingModel.
Returns:
dict: Result with list of texts, metadata, and embeddings.
Raises:
requests.exceptions.SSLError: Failure likely due to network issues.
"""
if not isinstance(texts, list):
raise ValueError("texts passed was not of type list")
if not isinstance(metadatas, list):
raise ValueError("metadatas passed was not of type list")

if len(texts) != len(metadatas):
LOGGER.warning(TEXT_METADATA_LENGTH_MISMATCH_WARNING)

text_embedding_items = [
{
"text": text,
"metadata": metadata,
}
for text, metadata in zip(texts, metadatas)
]
return self.embed_items(text_embedding_items=text_embedding_items, model=model)

def embed_and_join_metadata(
self,
text_embedding_items: List[Dict],
embedding_key: Hashable,
model: EmbeddingModel,
) -> Dict[str, List[Dict]]:
"""Takes some text and creates an embedding against a model in starpoint.
"""Takes some texts and creates embeddings using a model in starpoint, and joins them to
all additional data as metadata. Under the hood this is using `embed_and_join_metadata_by_columns`
which is using `embed_items`.
Args:
text: List of strings to create embeddings from.
model: A choice of
text_embedding_items: List of dicts of data to create embeddings from.
embedding_key: the key in each item used to create embeddings from.
e.g. `"context"` would be passed if each item looks like this: `{"context": "embed this text"}`
model: An enum choice from EmbeddingModel.
Returns:
dict: Result with multiple lists of embeddings, matching the number of requested strings to
create embeddings from.
dict: Result with list of texts, metadata, and embeddings.
Raises:
requests.exceptions.SSLError: Failure likely due to network issues.
"""
request_data = dict(text=text, model=model.value)
if not text_embedding_items:
raise ValueError("text_embedding_items received an empty list.")

texts = list(map(lambda item: item.get(embedding_key), text_embedding_items))
if not all(texts):
unqualified_indices = list(
more_itertools.locate(texts, lambda x: x is None)
)
raise ValueError(
"The following indices had items that did not have the "
f"{embedding_key}:\n {unqualified_indices}"
)

# We can also do this operation in the first map that creates texts, but that might make additional operations
# in here a lot more annoying. It's an optimization that shouldn't happen right now.
list(map(lambda item: item.pop(embedding_key), text_embedding_items))

return self.embed_and_join_metadata_by_columns(
texts=texts, metadatas=text_embedding_items, model=model
)

def embed_items(
self,
text_embedding_items: List[Dict],
model: EmbeddingModel,
) -> Dict[str, List[Dict]]:
"""Takes items with text and metadata, and embeds the text using a model in starpoint. Metadata is joined with
the results for ergonomics.
Args:
text_embedding_items: List of dict where the text and metadata are paired together
model: An enum choice from EmbeddingModel.
Returns:
dict: Result with list of texts, metadata, and embeddings.
Raises:
requests.exceptions.SSLError: Failure likely due to network issues.
"""

request_data = dict(items=text_embedding_items, model=model.value)
try:
response = requests.post(
url=f"{self.host}{EMBED_PATH}",
Expand Down
202 changes: 200 additions & 2 deletions python/tests/test_embedding.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
from unittest.mock import MagicMock, patch
from uuid import UUID, uuid4

Expand Down Expand Up @@ -55,7 +56,7 @@ def test_embedding_embed_not_200(
logger_mock = MagicMock()
monkeypatch.setattr(embedding, "LOGGER", logger_mock)

actual_json = mock_embedding_client.embed(["asdf"], embedding.EmbeddingModel.MINI6)
actual_json = mock_embedding_client.embed(["asdf"], embedding.EmbeddingModel.MINILM)

requests_mock.post.assert_called()
logger_mock.error.assert_called_once()
Expand All @@ -75,6 +76,203 @@ def test_embedding_embed_SSLError(
monkeypatch.setattr(embedding, "LOGGER", logger_mock)

with pytest.raises(SSLError, match="mock exception"):
mock_embedding_client.embed(["asdf"], embedding.EmbeddingModel.MINI6)
mock_embedding_client.embed(["asdf"], embedding.EmbeddingModel.MINILM)

logger_mock.error.assert_called_once_with(embedding.SSL_ERROR_MSG)


@patch("starpoint.embedding.EmbeddingClient.embed_items")
@patch("starpoint.embedding.requests")
def test_embedding_embed(
requests_mock: MagicMock,
embed_items_mock: MagicMock,
mock_embedding_client: embedding.EmbeddingClient,
):
input_text = "asdf"
input_model = embedding.EmbeddingModel.MINILM
expected_item = [{"text": input_text, "metadata": None}]

actual_json = mock_embedding_client.embed([input_text], input_model)

embed_items_mock.assert_called_once_with(
text_embedding_items=expected_item, model=input_model
)


@patch("starpoint.embedding.EmbeddingClient.embed_items")
@patch("starpoint.embedding.requests")
def test_embedding_embed_and_join_metadata_by_columns(
requests_mock: MagicMock,
embed_items_mock: MagicMock,
mock_embedding_client: embedding.EmbeddingClient,
):
input_text = "asdf"
input_metadata = {"label": "asdf"}
input_model = embedding.EmbeddingModel.MINILM
expected_item = [{"text": input_text, "metadata": input_metadata}]

actual_json = mock_embedding_client.embed_and_join_metadata_by_columns(
[input_text], [input_metadata], input_model
)

embed_items_mock.assert_called_once_with(
text_embedding_items=expected_item, model=input_model
)


@patch("starpoint.embedding.EmbeddingClient.embed_items")
@patch("starpoint.embedding.requests")
def test_embedding_embed_and_join_metadata_by_columns_non_list_texts(
requests_mock: MagicMock,
embed_items_mock: MagicMock,
mock_embedding_client: embedding.EmbeddingClient,
):
input_metadata = {"label": "asdf"}
input_model = embedding.EmbeddingModel.MINILM

with pytest.raises(ValueError):
mock_embedding_client.embed_and_join_metadata_by_columns(
"not_list_texts", [input_metadata], input_model
)


@patch("starpoint.embedding.EmbeddingClient.embed_items")
@patch("starpoint.embedding.requests")
def test_embedding_embed_and_join_metadata_by_columns_non_list_metadatas(
requests_mock: MagicMock,
embed_items_mock: MagicMock,
mock_embedding_client: embedding.EmbeddingClient,
):
input_text = "asdf"
input_model = embedding.EmbeddingModel.MINILM

with pytest.raises(ValueError):
mock_embedding_client.embed_and_join_metadata_by_columns(
[input_text], {"label": "not_list_metadatas"}, input_model
)


@pytest.mark.parametrize(
"input_text,input_metadata",
[
[
["embed_text1", "embed_text2"],
[{"label": "label1"}],
],
[
["embed_text1"],
[{"label": "label1"}, {"label": "label2"}],
],
],
)
@patch("starpoint.embedding.EmbeddingClient.embed_items")
@patch("starpoint.embedding.requests")
def test_embedding_embed_and_join_metadata_by_columns_mismatch_list(
requests_mock: MagicMock,
embed_items_mock: MagicMock,
mock_embedding_client: embedding.EmbeddingClient,
input_text: List,
input_metadata: List,
monkeypatch: MonkeyPatch,
):
input_model = embedding.EmbeddingModel.MINILM

logger_mock = MagicMock()
monkeypatch.setattr(embedding, "LOGGER", logger_mock)

actual_json = mock_embedding_client.embed_and_join_metadata_by_columns(
input_text, input_metadata, input_model
)

logger_mock.warning.assert_called_once_with(
embedding.TEXT_METADATA_LENGTH_MISMATCH_WARNING
)


@patch("starpoint.embedding.EmbeddingClient.embed_and_join_metadata_by_columns")
@patch("starpoint.embedding.requests")
def test_embedding_embed_and_join_metadata(
requests_mock: MagicMock,
embed_and_join_metadata_by_columns_mock: MagicMock,
mock_embedding_client: embedding.EmbeddingClient,
):
input_text = "embed text"
input_metadata = {
"metadata1": "metadata1",
"metadata2": "metadata2",
}
embed_key = "text"
input_dict = {embed_key: input_text}
input_dict.update(input_metadata)
test_embedding_items = [input_dict]

expected_item = [{"text": input_text, "metadata": input_metadata}]
input_model = embedding.EmbeddingModel.MINILM

actual_json = mock_embedding_client.embed_and_join_metadata(
test_embedding_items, embed_key, input_model
)

embed_and_join_metadata_by_columns_mock.assert_called_once_with(
texts=[input_text], metadatas=[input_metadata], model=input_model
)


@patch("starpoint.embedding.EmbeddingClient.embed_and_join_metadata_by_columns")
@patch("starpoint.embedding.requests")
def test_embedding_embed_and_join_metadata_no_embed_key(
requests_mock: MagicMock,
embed_and_join_metadata_by_columns_mock: MagicMock,
mock_embedding_client: embedding.EmbeddingClient,
):
input_text = "embed text"
input_metadata = {
"metadata1": "metadata1",
"metadata2": "metadata2",
}
input_dict = {"text": input_text}
input_dict.update(input_metadata)
test_embedding_items = [input_dict]

input_model = embedding.EmbeddingModel.MINILM
embed_key = "no key"

with pytest.raises(ValueError):
actual_json = mock_embedding_client.embed_and_join_metadata(
test_embedding_items, embed_key, input_model
)


@patch("starpoint.embedding.EmbeddingClient.embed_and_join_metadata_by_columns")
@patch("starpoint.embedding.requests")
def test_embedding_embed_and_join_metadata_no_values(
requests_mock: MagicMock,
embed_and_join_metadata_by_columns_mock: MagicMock,
mock_embedding_client: embedding.EmbeddingClient,
):
input_model = embedding.EmbeddingModel.MINILM
with pytest.raises(ValueError):
actual_json = mock_embedding_client.embed_and_join_metadata(
[], "text", input_model
)


@patch("starpoint.embedding.requests")
def test_embedding_embed_items(
requests_mock: MagicMock,
mock_embedding_client: embedding.EmbeddingClient,
):
requests_mock.post().ok = True
test_value = {"mock_return": "value"}
requests_mock.post().json.return_value = test_value

expected_json = {}

actual_json = mock_embedding_client.embed_items(
[{"text": "asdf", "metadata": {"label": "asdf"}}],
embedding.EmbeddingModel.MINILM,
)

requests_mock.post.assert_called()
requests_mock.post().json.assert_called()
assert actual_json == test_value

0 comments on commit 2e9cf0a

Please sign in to comment.