Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix embed endpoint #38

Merged
merged 15 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 74 additions & 9 deletions python/starpoint/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@

# Error and warning messages
SSL_ERROR_MSG = "Request failed due to SSLError. Error is likely due to invalid API key. Please check if your API is correct and still valid."
TEXT_METADATA_LENGTH_MISMATCH_WARNING = (
"The length of the texts and metadatas provided are different. There may be a mismatch "
"between texts and the metadatas length; this may cause undesired results between the joining of "
"embeddings and metadatas."
)


class EmbeddingModel(Enum):
MINI6 = "MINI6"
MINI12 = "MINI12"
MINILM = "MiniLm"


class EmbeddingClient(object):
Expand All @@ -41,23 +45,84 @@ def __init__(self, api_key: UUID, host: Optional[str] = None):

def embed(
self,
text: List[str],
texts: List[str],
model: EmbeddingModel,
) -> Dict[str, List[Dict]]:
"""Takes some texts creates embeddings using a model in starpoint. This is a
version of `embed_and_join_metadata_by_column` where joining metadata with the result is
not necessary. The same API is used for the two methods.

Args:
texts: List of strings to create embeddings from.
model: An enum choice from EmbeddingModel.

Returns:
dict: Result with list of texts, metadata, and embeddings.

Raises:
requests.exceptions.SSLError: Failure likely due to network issues.
"""
text_embedding_items = [{"text": text, "metadata": None} for text in texts]
return self.embed_items(text_embedding_items=text_embedding_items, model=model)

def embed_and_join_metadata_by_columns(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add one more method where you pass in a field along with a list of dicts like we discussed?

self,
texts: List[str],
metadatas: List[Dict],
model: EmbeddingModel,
) -> Dict[str, List[Dict]]:
"""Takes some texts and creates embeddings using a model in starpoint. Prefer using `embed_items`
instead, as mismatched `texts` and `metadatas` will output undesirable results.
Under the hood this is using `embed_items`.

Args:
texts: List of strings to create embeddings from.
metadatas: List of metadata to join to the string and embedding when the embedding operation is complete.
This metadata makes your embeddings queryable within starpoint.
model: An enum choice from EmbeddingModel.

Returns:
dict: Result with list of texts, metadata, and embeddings.

Raises:
requests.exceptions.SSLError: Failure likely due to network issues.
"""
if not isinstance(texts, list):
raise ValueError("texts passed was not of type list")
if not isinstance(metadatas, list):
raise ValueError("metadatas passed was not of type list")

if len(texts) != len(metadatas):
LOGGER.warning(TEXT_METADATA_LENGTH_MISMATCH_WARNING)

text_embedding_items = [
{
"text": text,
"metadata": metadata,
}
for text, metadata in zip(texts, metadatas)
]
return self.embed_items(text_embedding_items=text_embedding_items, model=model)

def embed_items(
self,
text_embedding_items: List[Dict],
model: EmbeddingModel,
) -> Dict[str, List[Dict]]:
"""Takes some text and creates an embedding against a model in starpoint.
"""Takes items with text and metadata, and embeds the text using a model in starpoint. Metadata is joined with
the results for ergonomics.

Args:
text: List of strings to create embeddings from.
model: A choice of
text_embedding_items: List of dict where the text and metadata are paired together
model: An enum choice from EmbeddingModel.

Returns:
dict: Result with multiple lists of embeddings, matching the number of requested strings to
create embeddings from.
dict: Result with list of texts, metadata, and embeddings.

Raises:
requests.exceptions.SSLError: Failure likely due to network issues.
"""
request_data = dict(text=text, model=model.value)
request_data = dict(items=text_embedding_items, model=model.value)
try:
response = requests.post(
url=f"{self.host}{EMBED_PATH}",
Expand Down
134 changes: 132 additions & 2 deletions python/tests/test_embedding.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
from unittest.mock import MagicMock, patch
from uuid import UUID, uuid4

Expand Down Expand Up @@ -55,7 +56,7 @@ def test_embedding_embed_not_200(
logger_mock = MagicMock()
monkeypatch.setattr(embedding, "LOGGER", logger_mock)

actual_json = mock_embedding_client.embed(["asdf"], embedding.EmbeddingModel.MINI6)
actual_json = mock_embedding_client.embed(["asdf"], embedding.EmbeddingModel.MINILM)

requests_mock.post.assert_called()
logger_mock.error.assert_called_once()
Expand All @@ -75,6 +76,135 @@ def test_embedding_embed_SSLError(
monkeypatch.setattr(embedding, "LOGGER", logger_mock)

with pytest.raises(SSLError, match="mock exception"):
mock_embedding_client.embed(["asdf"], embedding.EmbeddingModel.MINI6)
mock_embedding_client.embed(["asdf"], embedding.EmbeddingModel.MINILM)

logger_mock.error.assert_called_once_with(embedding.SSL_ERROR_MSG)


@patch("starpoint.embedding.EmbeddingClient.embed_items")
@patch("starpoint.embedding.requests")
def test_embedding_embed(
requests_mock: MagicMock,
embed_items_mock: MagicMock,
mock_embedding_client: embedding.EmbeddingClient,
):
input_text = "asdf"
input_model = embedding.EmbeddingModel.MINILM
expected_item = [{"text": input_text, "metadata": None}]

actual_json = mock_embedding_client.embed([input_text], input_model)

embed_items_mock.assert_called_once_with(
text_embedding_items=expected_item, model=input_model
)


@patch("starpoint.embedding.EmbeddingClient.embed_items")
@patch("starpoint.embedding.requests")
def test_embedding_embed_and_join_metadata_by_columns(
requests_mock: MagicMock,
embed_items_mock: MagicMock,
mock_embedding_client: embedding.EmbeddingClient,
):
input_text = "asdf"
input_metadata = {"label": "asdf"}
input_model = embedding.EmbeddingModel.MINILM
expected_item = [{"text": input_text, "metadata": input_metadata}]

actual_json = mock_embedding_client.embed_and_join_metadata_by_columns(
[input_text], [input_metadata], input_model
)

embed_items_mock.assert_called_once_with(
text_embedding_items=expected_item, model=input_model
)


@patch("starpoint.embedding.EmbeddingClient.embed_items")
@patch("starpoint.embedding.requests")
def test_embedding_embed_and_join_metadata_by_columns_non_list_texts(
requests_mock: MagicMock,
embed_items_mock: MagicMock,
mock_embedding_client: embedding.EmbeddingClient,
):
input_metadata = {"label": "asdf"}
input_model = embedding.EmbeddingModel.MINILM

with pytest.raises(ValueError):
mock_embedding_client.embed_and_join_metadata_by_columns(
"not_list_texts", [input_metadata], input_model
)


@patch("starpoint.embedding.EmbeddingClient.embed_items")
@patch("starpoint.embedding.requests")
def test_embedding_embed_and_join_metadata_by_columns_non_list_metadatas(
requests_mock: MagicMock,
embed_items_mock: MagicMock,
mock_embedding_client: embedding.EmbeddingClient,
):
input_text = "asdf"
input_model = embedding.EmbeddingModel.MINILM

with pytest.raises(ValueError):
mock_embedding_client.embed_and_join_metadata_by_columns(
[input_text], {"label": "not_list_metadatas"}, input_model
)


@pytest.mark.parametrize(
"input_text,input_metadata",
[
[
["embed_text1", "embed_text2"],
[{"label": "label1"}],
],
[
["embed_text1"],
[{"label": "label1"}, {"label": "label2"}],
],
],
)
@patch("starpoint.embedding.EmbeddingClient.embed_items")
@patch("starpoint.embedding.requests")
def test_embedding_embed_and_join_metadata_by_columns_mismatch_list(
requests_mock: MagicMock,
embed_items_mock: MagicMock,
mock_embedding_client: embedding.EmbeddingClient,
input_text: List,
input_metadata: List,
monkeypatch: MonkeyPatch,
):
input_model = embedding.EmbeddingModel.MINILM

logger_mock = MagicMock()
monkeypatch.setattr(embedding, "LOGGER", logger_mock)

actual_json = mock_embedding_client.embed_and_join_metadata_by_columns(
input_text, input_metadata, input_model
)

logger_mock.warning.assert_called_once_with(
embedding.TEXT_METADATA_LENGTH_MISMATCH_WARNING
)


@patch("starpoint.embedding.requests")
def test_embedding_embed_items(
requests_mock: MagicMock,
mock_embedding_client: embedding.EmbeddingClient,
):
requests_mock.post().ok = True
test_value = {"mock_return": "value"}
requests_mock.post().json.return_value = test_value

expected_json = {}

actual_json = mock_embedding_client.embed_items(
[{"text": "asdf", "metadata": {"label": "asdf"}}],
embedding.EmbeddingModel.MINILM,
)

requests_mock.post.assert_called()
requests_mock.post().json.assert_called()
assert actual_json == test_value