From 2e9cf0a0de51fa31cac07334b3b52df92e606c21 Mon Sep 17 00:00:00 2001 From: FullMetalMeowchemist <117529599+FullMetalMeowchemist@users.noreply.github.com> Date: Mon, 18 Sep 2023 10:00:25 -0500 Subject: [PATCH] Fix embed endpoint (#38) * Remove unsupported embedding models * Add methods to make requests using items instead of texts * Fix embed method and clarify variable names * Add TODO and fix tests * Add tests for embed endpoints * Function name and docstring changes and add logging for mismatch lists * Remove TODO * Add guarding for embed_and_join_metadata_by_columns * Add object-based embedding method * Add some notes * Clean up docstrings * Fix bugs and add tests --- python/pyproject.toml | 1 + python/starpoint/embedding.py | 130 +++++++++++++++++++-- python/tests/test_embedding.py | 202 ++++++++++++++++++++++++++++++++- 3 files changed, 321 insertions(+), 12 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 6d76adc..15540bc 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "requests~=2.31", "validators~=0.21", "pandas~=2.0", + "more-itertools~=10.1", ] [project.optional-dependencies] diff --git a/python/starpoint/embedding.py b/python/starpoint/embedding.py index c488a99..83b0582 100644 --- a/python/starpoint/embedding.py +++ b/python/starpoint/embedding.py @@ -1,6 +1,7 @@ import logging +import more_itertools from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Hashable, List, Optional from uuid import UUID import requests @@ -22,11 +23,15 @@ # Error and warning messages SSL_ERROR_MSG = "Request failed due to SSLError. Error is likely due to invalid API key. Please check if your API is correct and still valid." +TEXT_METADATA_LENGTH_MISMATCH_WARNING = ( + "The length of the texts and metadatas provided are different. There may be a mismatch " + "between texts and the metadatas length; this may cause undesired results between the joining of " + "embeddings and metadatas." +) class EmbeddingModel(Enum): - MINI6 = "MINI6" - MINI12 = "MINI12" + MINILM = "MiniLm" class EmbeddingClient(object): @@ -41,23 +46,128 @@ def __init__(self, api_key: UUID, host: Optional[str] = None): def embed( self, - text: List[str], + texts: List[str], + model: EmbeddingModel, + ) -> Dict[str, List[Dict]]: + """Takes some texts creates embeddings using a model in starpoint. This is a + version of `embed_and_join_metadata_by_column` where joining metadata with the result is + not necessary. The same API is used for the two methods. + + Args: + texts: List of strings to create embeddings from. + model: An enum choice from EmbeddingModel. + + Returns: + dict: Result with list of texts, metadata, and embeddings. + + Raises: + requests.exceptions.SSLError: Failure likely due to network issues. + """ + text_embedding_items = [{"text": text, "metadata": None} for text in texts] + return self.embed_items(text_embedding_items=text_embedding_items, model=model) + + def embed_and_join_metadata_by_columns( + self, + texts: List[str], + metadatas: List[Dict], + model: EmbeddingModel, + ) -> Dict[str, List[Dict]]: + """Takes some texts and creates embeddings using a model in starpoint. Prefer using `embed_and_join_metadata` or + `embed_items` instead, as mismatched `texts` and `metadatas` will output undesirable results. + Under the hood this is using `embed_items`. + + Args: + texts: List of strings to create embeddings from. + metadatas: List of metadata to join to the string and embedding when the embedding operation is complete. + This metadata makes your embeddings queryable within starpoint. + model: An enum choice from EmbeddingModel. + + Returns: + dict: Result with list of texts, metadata, and embeddings. + + Raises: + requests.exceptions.SSLError: Failure likely due to network issues. + """ + if not isinstance(texts, list): + raise ValueError("texts passed was not of type list") + if not isinstance(metadatas, list): + raise ValueError("metadatas passed was not of type list") + + if len(texts) != len(metadatas): + LOGGER.warning(TEXT_METADATA_LENGTH_MISMATCH_WARNING) + + text_embedding_items = [ + { + "text": text, + "metadata": metadata, + } + for text, metadata in zip(texts, metadatas) + ] + return self.embed_items(text_embedding_items=text_embedding_items, model=model) + + def embed_and_join_metadata( + self, + text_embedding_items: List[Dict], + embedding_key: Hashable, model: EmbeddingModel, ) -> Dict[str, List[Dict]]: - """Takes some text and creates an embedding against a model in starpoint. + """Takes some texts and creates embeddings using a model in starpoint, and joins them to + all additional data as metadata. Under the hood this is using `embed_and_join_metadata_by_columns` + which is using `embed_items`. Args: - text: List of strings to create embeddings from. - model: A choice of + text_embedding_items: List of dicts of data to create embeddings from. + embedding_key: the key in each item used to create embeddings from. + e.g. `"context"` would be passed if each item looks like this: `{"context": "embed this text"}` + model: An enum choice from EmbeddingModel. Returns: - dict: Result with multiple lists of embeddings, matching the number of requested strings to - create embeddings from. + dict: Result with list of texts, metadata, and embeddings. Raises: requests.exceptions.SSLError: Failure likely due to network issues. """ - request_data = dict(text=text, model=model.value) + if not text_embedding_items: + raise ValueError("text_embedding_items received an empty list.") + + texts = list(map(lambda item: item.get(embedding_key), text_embedding_items)) + if not all(texts): + unqualified_indices = list( + more_itertools.locate(texts, lambda x: x is None) + ) + raise ValueError( + "The following indices had items that did not have the " + f"{embedding_key}:\n {unqualified_indices}" + ) + + # We can also do this operation in the first map that creates texts, but that might make additional operations + # in here a lot more annoying. It's an optimization that shouldn't happen right now. + list(map(lambda item: item.pop(embedding_key), text_embedding_items)) + + return self.embed_and_join_metadata_by_columns( + texts=texts, metadatas=text_embedding_items, model=model + ) + + def embed_items( + self, + text_embedding_items: List[Dict], + model: EmbeddingModel, + ) -> Dict[str, List[Dict]]: + """Takes items with text and metadata, and embeds the text using a model in starpoint. Metadata is joined with + the results for ergonomics. + + Args: + text_embedding_items: List of dict where the text and metadata are paired together + model: An enum choice from EmbeddingModel. + + Returns: + dict: Result with list of texts, metadata, and embeddings. + + Raises: + requests.exceptions.SSLError: Failure likely due to network issues. + """ + + request_data = dict(items=text_embedding_items, model=model.value) try: response = requests.post( url=f"{self.host}{EMBED_PATH}", diff --git a/python/tests/test_embedding.py b/python/tests/test_embedding.py index cd0f365..9e7887b 100644 --- a/python/tests/test_embedding.py +++ b/python/tests/test_embedding.py @@ -1,3 +1,4 @@ +from typing import List from unittest.mock import MagicMock, patch from uuid import UUID, uuid4 @@ -55,7 +56,7 @@ def test_embedding_embed_not_200( logger_mock = MagicMock() monkeypatch.setattr(embedding, "LOGGER", logger_mock) - actual_json = mock_embedding_client.embed(["asdf"], embedding.EmbeddingModel.MINI6) + actual_json = mock_embedding_client.embed(["asdf"], embedding.EmbeddingModel.MINILM) requests_mock.post.assert_called() logger_mock.error.assert_called_once() @@ -75,6 +76,203 @@ def test_embedding_embed_SSLError( monkeypatch.setattr(embedding, "LOGGER", logger_mock) with pytest.raises(SSLError, match="mock exception"): - mock_embedding_client.embed(["asdf"], embedding.EmbeddingModel.MINI6) + mock_embedding_client.embed(["asdf"], embedding.EmbeddingModel.MINILM) logger_mock.error.assert_called_once_with(embedding.SSL_ERROR_MSG) + + +@patch("starpoint.embedding.EmbeddingClient.embed_items") +@patch("starpoint.embedding.requests") +def test_embedding_embed( + requests_mock: MagicMock, + embed_items_mock: MagicMock, + mock_embedding_client: embedding.EmbeddingClient, +): + input_text = "asdf" + input_model = embedding.EmbeddingModel.MINILM + expected_item = [{"text": input_text, "metadata": None}] + + actual_json = mock_embedding_client.embed([input_text], input_model) + + embed_items_mock.assert_called_once_with( + text_embedding_items=expected_item, model=input_model + ) + + +@patch("starpoint.embedding.EmbeddingClient.embed_items") +@patch("starpoint.embedding.requests") +def test_embedding_embed_and_join_metadata_by_columns( + requests_mock: MagicMock, + embed_items_mock: MagicMock, + mock_embedding_client: embedding.EmbeddingClient, +): + input_text = "asdf" + input_metadata = {"label": "asdf"} + input_model = embedding.EmbeddingModel.MINILM + expected_item = [{"text": input_text, "metadata": input_metadata}] + + actual_json = mock_embedding_client.embed_and_join_metadata_by_columns( + [input_text], [input_metadata], input_model + ) + + embed_items_mock.assert_called_once_with( + text_embedding_items=expected_item, model=input_model + ) + + +@patch("starpoint.embedding.EmbeddingClient.embed_items") +@patch("starpoint.embedding.requests") +def test_embedding_embed_and_join_metadata_by_columns_non_list_texts( + requests_mock: MagicMock, + embed_items_mock: MagicMock, + mock_embedding_client: embedding.EmbeddingClient, +): + input_metadata = {"label": "asdf"} + input_model = embedding.EmbeddingModel.MINILM + + with pytest.raises(ValueError): + mock_embedding_client.embed_and_join_metadata_by_columns( + "not_list_texts", [input_metadata], input_model + ) + + +@patch("starpoint.embedding.EmbeddingClient.embed_items") +@patch("starpoint.embedding.requests") +def test_embedding_embed_and_join_metadata_by_columns_non_list_metadatas( + requests_mock: MagicMock, + embed_items_mock: MagicMock, + mock_embedding_client: embedding.EmbeddingClient, +): + input_text = "asdf" + input_model = embedding.EmbeddingModel.MINILM + + with pytest.raises(ValueError): + mock_embedding_client.embed_and_join_metadata_by_columns( + [input_text], {"label": "not_list_metadatas"}, input_model + ) + + +@pytest.mark.parametrize( + "input_text,input_metadata", + [ + [ + ["embed_text1", "embed_text2"], + [{"label": "label1"}], + ], + [ + ["embed_text1"], + [{"label": "label1"}, {"label": "label2"}], + ], + ], +) +@patch("starpoint.embedding.EmbeddingClient.embed_items") +@patch("starpoint.embedding.requests") +def test_embedding_embed_and_join_metadata_by_columns_mismatch_list( + requests_mock: MagicMock, + embed_items_mock: MagicMock, + mock_embedding_client: embedding.EmbeddingClient, + input_text: List, + input_metadata: List, + monkeypatch: MonkeyPatch, +): + input_model = embedding.EmbeddingModel.MINILM + + logger_mock = MagicMock() + monkeypatch.setattr(embedding, "LOGGER", logger_mock) + + actual_json = mock_embedding_client.embed_and_join_metadata_by_columns( + input_text, input_metadata, input_model + ) + + logger_mock.warning.assert_called_once_with( + embedding.TEXT_METADATA_LENGTH_MISMATCH_WARNING + ) + + +@patch("starpoint.embedding.EmbeddingClient.embed_and_join_metadata_by_columns") +@patch("starpoint.embedding.requests") +def test_embedding_embed_and_join_metadata( + requests_mock: MagicMock, + embed_and_join_metadata_by_columns_mock: MagicMock, + mock_embedding_client: embedding.EmbeddingClient, +): + input_text = "embed text" + input_metadata = { + "metadata1": "metadata1", + "metadata2": "metadata2", + } + embed_key = "text" + input_dict = {embed_key: input_text} + input_dict.update(input_metadata) + test_embedding_items = [input_dict] + + expected_item = [{"text": input_text, "metadata": input_metadata}] + input_model = embedding.EmbeddingModel.MINILM + + actual_json = mock_embedding_client.embed_and_join_metadata( + test_embedding_items, embed_key, input_model + ) + + embed_and_join_metadata_by_columns_mock.assert_called_once_with( + texts=[input_text], metadatas=[input_metadata], model=input_model + ) + + +@patch("starpoint.embedding.EmbeddingClient.embed_and_join_metadata_by_columns") +@patch("starpoint.embedding.requests") +def test_embedding_embed_and_join_metadata_no_embed_key( + requests_mock: MagicMock, + embed_and_join_metadata_by_columns_mock: MagicMock, + mock_embedding_client: embedding.EmbeddingClient, +): + input_text = "embed text" + input_metadata = { + "metadata1": "metadata1", + "metadata2": "metadata2", + } + input_dict = {"text": input_text} + input_dict.update(input_metadata) + test_embedding_items = [input_dict] + + input_model = embedding.EmbeddingModel.MINILM + embed_key = "no key" + + with pytest.raises(ValueError): + actual_json = mock_embedding_client.embed_and_join_metadata( + test_embedding_items, embed_key, input_model + ) + + +@patch("starpoint.embedding.EmbeddingClient.embed_and_join_metadata_by_columns") +@patch("starpoint.embedding.requests") +def test_embedding_embed_and_join_metadata_no_values( + requests_mock: MagicMock, + embed_and_join_metadata_by_columns_mock: MagicMock, + mock_embedding_client: embedding.EmbeddingClient, +): + input_model = embedding.EmbeddingModel.MINILM + with pytest.raises(ValueError): + actual_json = mock_embedding_client.embed_and_join_metadata( + [], "text", input_model + ) + + +@patch("starpoint.embedding.requests") +def test_embedding_embed_items( + requests_mock: MagicMock, + mock_embedding_client: embedding.EmbeddingClient, +): + requests_mock.post().ok = True + test_value = {"mock_return": "value"} + requests_mock.post().json.return_value = test_value + + expected_json = {} + + actual_json = mock_embedding_client.embed_items( + [{"text": "asdf", "metadata": {"label": "asdf"}}], + embedding.EmbeddingModel.MINILM, + ) + + requests_mock.post.assert_called() + requests_mock.post().json.assert_called() + assert actual_json == test_value