From 50a246a4f7f4a6957743751f28a9d7f386013486 Mon Sep 17 00:00:00 2001 From: Zapple Date: Wed, 13 Sep 2023 13:55:16 -0500 Subject: [PATCH 01/12] Remove unsupported embedding models --- python/starpoint/embedding.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/starpoint/embedding.py b/python/starpoint/embedding.py index c488a99..2db4cf1 100644 --- a/python/starpoint/embedding.py +++ b/python/starpoint/embedding.py @@ -25,8 +25,7 @@ class EmbeddingModel(Enum): - MINI6 = "MINI6" - MINI12 = "MINI12" + MINILM = "MiniLm" class EmbeddingClient(object): From 2ec24fc015f3ced24559037ac4f47d01dfcaf45b Mon Sep 17 00:00:00 2001 From: Zapple Date: Wed, 13 Sep 2023 16:25:19 -0500 Subject: [PATCH 02/12] Add methods to make requests using items instead of texts --- python/starpoint/embedding.py | 83 ++++++++++++++++++++++++++++++++--- 1 file changed, 77 insertions(+), 6 deletions(-) diff --git a/python/starpoint/embedding.py b/python/starpoint/embedding.py index 2db4cf1..f84229f 100644 --- a/python/starpoint/embedding.py +++ b/python/starpoint/embedding.py @@ -38,20 +38,22 @@ def __init__(self, api_key: UUID, host: Optional[str] = None): self.host = _validate_host(host) self.api_key = api_key + # FIXME: auto-fill the metadata field def embed( self, - text: List[str], + texts: List[str], model: EmbeddingModel, ) -> Dict[str, List[Dict]]: - """Takes some text and creates an embedding against a model in starpoint. + """Takes some texts creates embeddings using a model in starpoint. This is a + version of embed_and_join_metadata where joining metadata with the result is + not necessary. The same API is used between the two methods. Args: - text: List of strings to create embeddings from. - model: A choice of + texts: List of strings to create embeddings from. + model: An enum choice from EmbeddingModel. Returns: - dict: Result with multiple lists of embeddings, matching the number of requested strings to - create embeddings from. + dict: Result with list of texts, metadata, and embeddings. Raises: requests.exceptions.SSLError: Failure likely due to network issues. @@ -77,3 +79,72 @@ def embed( ) return {} return response.json() + + def embed_and_join_metadata( + self, + texts: List[str], + metadatas: List[Dict], + model: EmbeddingModel, + ) -> Dict[str, List[Dict]]: + """Takes some texts and creates embeddings using a model in starpoint. Metadata is joined with + the results for ergonomics. Under the hood this is using embed_items. + + Args: + texts: List of strings to create embeddings from. + metadatas: List of metadata to join to the string and embedding when the embedding operation is complete. + model: An enum choice from EmbeddingModel. + + Returns: + dict: Result with list of texts, metadata, and embeddings. + + Raises: + requests.exceptions.SSLError: Failure likely due to network issues. + """ + items = [ + { + "text": text, + "metadata": metadata, + } + for text, metadata in zip(texts, metadatas) + ] + return self.embed_items(items=items, model=model) + + def embed_items( + self, + items: List[Dict], + model: EmbeddingModel, + ) -> Dict[str, List[Dict]]: + """Takes items with text and metadata, and embeds the text using a model in starpoint. Metadata is joined with + the results for ergonomics. + + Args: + items: List of dict where the text and metadata are paired together + model: An enum choice from EmbeddingModel. + + Returns: + dict: Result with list of texts, metadata, and embeddings. + + Raises: + requests.exceptions.SSLError: Failure likely due to network issues. + """ + request_data = dict(items=items, model=model.value) + try: + response = requests.post( + url=f"{self.host}{EMBED_PATH}", + json=request_data, + headers=_build_header( + api_key=self.api_key, + additional_headers={"Content-Type": "application/json"}, + ), + ) + except requests.exceptions.SSLError as e: + LOGGER.error(SSL_ERROR_MSG) + raise + + if not response.ok: + LOGGER.error( + f"Request failed with status code {response.status_code} " + f"and the following message:\n{response.text}" + ) + return {} + return response.json() From 52b4d6612dfc36e02f7f137e0cd6f5e51f444be7 Mon Sep 17 00:00:00 2001 From: Zapple Date: Wed, 13 Sep 2023 16:39:43 -0500 Subject: [PATCH 03/12] Fix embed method and clarify variable names --- python/starpoint/embedding.py | 35 ++++++++--------------------------- 1 file changed, 8 insertions(+), 27 deletions(-) diff --git a/python/starpoint/embedding.py b/python/starpoint/embedding.py index f84229f..7fba4ba 100644 --- a/python/starpoint/embedding.py +++ b/python/starpoint/embedding.py @@ -38,7 +38,6 @@ def __init__(self, api_key: UUID, host: Optional[str] = None): self.host = _validate_host(host) self.api_key = api_key - # FIXME: auto-fill the metadata field def embed( self, texts: List[str], @@ -58,27 +57,8 @@ def embed( Raises: requests.exceptions.SSLError: Failure likely due to network issues. """ - request_data = dict(text=text, model=model.value) - try: - response = requests.post( - url=f"{self.host}{EMBED_PATH}", - json=request_data, - headers=_build_header( - api_key=self.api_key, - additional_headers={"Content-Type": "application/json"}, - ), - ) - except requests.exceptions.SSLError as e: - LOGGER.error(SSL_ERROR_MSG) - raise - - if not response.ok: - LOGGER.error( - f"Request failed with status code {response.status_code} " - f"and the following message:\n{response.text}" - ) - return {} - return response.json() + text_embedding_items = [{"text": text, "metadata": None} for text in texts] + return self.embed_items(text_embedding_items=text_embedding_items, model=model) def embed_and_join_metadata( self, @@ -92,6 +72,7 @@ def embed_and_join_metadata( Args: texts: List of strings to create embeddings from. metadatas: List of metadata to join to the string and embedding when the embedding operation is complete. + This metadata makes your embeddings queryable within starpoint. model: An enum choice from EmbeddingModel. Returns: @@ -100,25 +81,25 @@ def embed_and_join_metadata( Raises: requests.exceptions.SSLError: Failure likely due to network issues. """ - items = [ + text_embedding_items = [ { "text": text, "metadata": metadata, } for text, metadata in zip(texts, metadatas) ] - return self.embed_items(items=items, model=model) + return self.embed_items(text_embedding_items=text_embedding_items, model=model) def embed_items( self, - items: List[Dict], + text_embedding_items: List[Dict], model: EmbeddingModel, ) -> Dict[str, List[Dict]]: """Takes items with text and metadata, and embeds the text using a model in starpoint. Metadata is joined with the results for ergonomics. Args: - items: List of dict where the text and metadata are paired together + text_embedding_items: List of dict where the text and metadata are paired together model: An enum choice from EmbeddingModel. Returns: @@ -127,7 +108,7 @@ def embed_items( Raises: requests.exceptions.SSLError: Failure likely due to network issues. """ - request_data = dict(items=items, model=model.value) + request_data = dict(items=text_embedding_items, model=model.value) try: response = requests.post( url=f"{self.host}{EMBED_PATH}", From 87bc4900c116583f1d47685ede04c804f499883e Mon Sep 17 00:00:00 2001 From: Zapple Date: Wed, 13 Sep 2023 16:52:54 -0500 Subject: [PATCH 04/12] Add TODO and fix tests --- python/starpoint/embedding.py | 1 + python/tests/test_embedding.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/starpoint/embedding.py b/python/starpoint/embedding.py index 7fba4ba..de3db39 100644 --- a/python/starpoint/embedding.py +++ b/python/starpoint/embedding.py @@ -81,6 +81,7 @@ def embed_and_join_metadata( Raises: requests.exceptions.SSLError: Failure likely due to network issues. """ + # TODO: add len + logging here text_embedding_items = [ { "text": text, diff --git a/python/tests/test_embedding.py b/python/tests/test_embedding.py index cd0f365..0f09690 100644 --- a/python/tests/test_embedding.py +++ b/python/tests/test_embedding.py @@ -55,7 +55,7 @@ def test_embedding_embed_not_200( logger_mock = MagicMock() monkeypatch.setattr(embedding, "LOGGER", logger_mock) - actual_json = mock_embedding_client.embed(["asdf"], embedding.EmbeddingModel.MINI6) + actual_json = mock_embedding_client.embed(["asdf"], embedding.EmbeddingModel.MINILM) requests_mock.post.assert_called() logger_mock.error.assert_called_once() @@ -75,6 +75,6 @@ def test_embedding_embed_SSLError( monkeypatch.setattr(embedding, "LOGGER", logger_mock) with pytest.raises(SSLError, match="mock exception"): - mock_embedding_client.embed(["asdf"], embedding.EmbeddingModel.MINI6) + mock_embedding_client.embed(["asdf"], embedding.EmbeddingModel.MINILM) logger_mock.error.assert_called_once_with(embedding.SSL_ERROR_MSG) From 2df398d96d091f51a5bd306754c2cefc4309cd0d Mon Sep 17 00:00:00 2001 From: Zapple Date: Wed, 13 Sep 2023 17:29:27 -0500 Subject: [PATCH 05/12] Add tests for embed endpoints --- python/tests/test_embedding.py | 61 ++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/python/tests/test_embedding.py b/python/tests/test_embedding.py index 0f09690..6d90bb3 100644 --- a/python/tests/test_embedding.py +++ b/python/tests/test_embedding.py @@ -78,3 +78,64 @@ def test_embedding_embed_SSLError( mock_embedding_client.embed(["asdf"], embedding.EmbeddingModel.MINILM) logger_mock.error.assert_called_once_with(embedding.SSL_ERROR_MSG) + + +@patch("starpoint.embedding.EmbeddingClient.embed_items") +@patch("starpoint.embedding.requests") +def test_embedding_embed( + requests_mock: MagicMock, + embed_items_mock: MagicMock, + mock_embedding_client: embedding.EmbeddingClient, +): + input_text = "asdf" + input_model = embedding.EmbeddingModel.MINILM + expected_item = [{"text": input_text, "metadata": None}] + + actual_json = mock_embedding_client.embed([input_text], input_model) + + embed_items_mock.assert_called_once_with( + text_embedding_items=expected_item, model=input_model + ) + + +@patch("starpoint.embedding.EmbeddingClient.embed_items") +@patch("starpoint.embedding.requests") +def test_embedding_embed_and_join_metadata( + requests_mock: MagicMock, + embed_items_mock: MagicMock, + mock_embedding_client: embedding.EmbeddingClient, +): + input_text = "asdf" + input_metadata = {"label": "asdf"} + input_model = embedding.EmbeddingModel.MINILM + expected_item = [{"text": input_text, "metadata": input_metadata}] + + actual_json = mock_embedding_client.embed_and_join_metadata( + [input_text], [input_metadata], input_model + ) + + embed_items_mock.assert_called_once_with( + text_embedding_items=expected_item, model=input_model + ) + + +@patch("starpoint.embedding.requests") +def test_embedding_embed_items( + requests_mock: MagicMock, + mock_embedding_client: embedding.EmbeddingClient, + monkeypatch: MonkeyPatch, +): + requests_mock.post().ok = True + test_value = {"mock_return": "value"} + requests_mock.post().json.return_value = test_value + + expected_json = {} + + actual_json = mock_embedding_client.embed_items( + [{"text": "asdf", "metadata": {"label": "asdf"}}], + embedding.EmbeddingModel.MINILM, + ) + + requests_mock.post.assert_called() + requests_mock.post().json.assert_called() + assert actual_json == test_value From 908e71f31b66614945b1d60de82f2de5d9bc7c2a Mon Sep 17 00:00:00 2001 From: Zapple Date: Thu, 14 Sep 2023 16:18:37 -0500 Subject: [PATCH 06/12] Function name and docstring changes and add logging for mismatch lists --- python/starpoint/embedding.py | 19 ++++++++++++++----- python/tests/test_embedding.py | 4 ++-- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/python/starpoint/embedding.py b/python/starpoint/embedding.py index de3db39..a8e8677 100644 --- a/python/starpoint/embedding.py +++ b/python/starpoint/embedding.py @@ -22,6 +22,11 @@ # Error and warning messages SSL_ERROR_MSG = "Request failed due to SSLError. Error is likely due to invalid API key. Please check if your API is correct and still valid." +TEXT_METADATA_LENGTH_MISMATCH_WARNING = ( + "The length of the texts and metadatas provided are different. There may be a mismatch " + "between texts and the metadatas length; this may cause undesired results between the joining of " + "embeddings and metadatas." +) class EmbeddingModel(Enum): @@ -44,8 +49,8 @@ def embed( model: EmbeddingModel, ) -> Dict[str, List[Dict]]: """Takes some texts creates embeddings using a model in starpoint. This is a - version of embed_and_join_metadata where joining metadata with the result is - not necessary. The same API is used between the two methods. + version of `embed_and_join_metadata_by_column` where joining metadata with the result is + not necessary. The same API is used for the two methods. Args: texts: List of strings to create embeddings from. @@ -60,14 +65,15 @@ def embed( text_embedding_items = [{"text": text, "metadata": None} for text in texts] return self.embed_items(text_embedding_items=text_embedding_items, model=model) - def embed_and_join_metadata( + def embed_and_join_metadata_by_columns( self, texts: List[str], metadatas: List[Dict], model: EmbeddingModel, ) -> Dict[str, List[Dict]]: - """Takes some texts and creates embeddings using a model in starpoint. Metadata is joined with - the results for ergonomics. Under the hood this is using embed_items. + """Takes some texts and creates embeddings using a model in starpoint. Prefer using `embed_items` + instead, as mismatched `texts` and `metadatas` will output undesirable results. + Under the hood this is using `embed_items`. Args: texts: List of strings to create embeddings from. @@ -82,6 +88,9 @@ def embed_and_join_metadata( requests.exceptions.SSLError: Failure likely due to network issues. """ # TODO: add len + logging here + if len(texts) != len(metadatas): + LOGGER.warning(TEXT_METADATA_LENGTH_MISMATCH_WARNING) + text_embedding_items = [ { "text": text, diff --git a/python/tests/test_embedding.py b/python/tests/test_embedding.py index 6d90bb3..a752b62 100644 --- a/python/tests/test_embedding.py +++ b/python/tests/test_embedding.py @@ -100,7 +100,7 @@ def test_embedding_embed( @patch("starpoint.embedding.EmbeddingClient.embed_items") @patch("starpoint.embedding.requests") -def test_embedding_embed_and_join_metadata( +def test_embedding_embed_and_join_metadata_by_columns( requests_mock: MagicMock, embed_items_mock: MagicMock, mock_embedding_client: embedding.EmbeddingClient, @@ -110,7 +110,7 @@ def test_embedding_embed_and_join_metadata( input_model = embedding.EmbeddingModel.MINILM expected_item = [{"text": input_text, "metadata": input_metadata}] - actual_json = mock_embedding_client.embed_and_join_metadata( + actual_json = mock_embedding_client.embed_and_join_metadata_by_columns( [input_text], [input_metadata], input_model ) From 42f7da16daaebfb75b9ded41825f4000ec5db29d Mon Sep 17 00:00:00 2001 From: Zapple Date: Thu, 14 Sep 2023 16:19:04 -0500 Subject: [PATCH 07/12] Remove TODO --- python/starpoint/embedding.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/starpoint/embedding.py b/python/starpoint/embedding.py index a8e8677..1179600 100644 --- a/python/starpoint/embedding.py +++ b/python/starpoint/embedding.py @@ -87,7 +87,6 @@ def embed_and_join_metadata_by_columns( Raises: requests.exceptions.SSLError: Failure likely due to network issues. """ - # TODO: add len + logging here if len(texts) != len(metadatas): LOGGER.warning(TEXT_METADATA_LENGTH_MISMATCH_WARNING) From 8f3f654ae6bdf4a933e5a41bb7a171ed187da87a Mon Sep 17 00:00:00 2001 From: Zapple Date: Thu, 14 Sep 2023 16:47:37 -0500 Subject: [PATCH 08/12] Add guarding for embed_and_join_metadata_by_columns --- python/starpoint/embedding.py | 5 +++ python/tests/test_embedding.py | 71 +++++++++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/python/starpoint/embedding.py b/python/starpoint/embedding.py index 1179600..f5fe5a1 100644 --- a/python/starpoint/embedding.py +++ b/python/starpoint/embedding.py @@ -87,6 +87,11 @@ def embed_and_join_metadata_by_columns( Raises: requests.exceptions.SSLError: Failure likely due to network issues. """ + if not isinstance(texts, list): + raise ValueError("texts passed was not of type list") + if not isinstance(metadatas, list): + raise ValueError("metadatas passed was not of type list") + if len(texts) != len(metadatas): LOGGER.warning(TEXT_METADATA_LENGTH_MISMATCH_WARNING) diff --git a/python/tests/test_embedding.py b/python/tests/test_embedding.py index a752b62..8062be8 100644 --- a/python/tests/test_embedding.py +++ b/python/tests/test_embedding.py @@ -1,3 +1,4 @@ +from typing import List from unittest.mock import MagicMock, patch from uuid import UUID, uuid4 @@ -119,11 +120,79 @@ def test_embedding_embed_and_join_metadata_by_columns( ) +@patch("starpoint.embedding.EmbeddingClient.embed_items") @patch("starpoint.embedding.requests") -def test_embedding_embed_items( +def test_embedding_embed_and_join_metadata_by_columns_non_list_texts( + requests_mock: MagicMock, + embed_items_mock: MagicMock, + mock_embedding_client: embedding.EmbeddingClient, +): + input_metadata = {"label": "asdf"} + input_model = embedding.EmbeddingModel.MINILM + + with pytest.raises(ValueError): + mock_embedding_client.embed_and_join_metadata_by_columns( + "not_list_texts", [input_metadata], input_model + ) + + +@patch("starpoint.embedding.EmbeddingClient.embed_items") +@patch("starpoint.embedding.requests") +def test_embedding_embed_and_join_metadata_by_columns_non_list_metadatas( requests_mock: MagicMock, + embed_items_mock: MagicMock, mock_embedding_client: embedding.EmbeddingClient, +): + input_text = "asdf" + input_model = embedding.EmbeddingModel.MINILM + + with pytest.raises(ValueError): + mock_embedding_client.embed_and_join_metadata_by_columns( + [input_text], {"label": "not_list_metadatas"}, input_model + ) + + +@pytest.mark.parametrize( + "input_text,input_metadata", + [ + [ + ["embed_text1", "embed_text2"], + [{"label": "label1"}], + ], + [ + ["embed_text1"], + [{"label": "label1"}, {"label": "label2"}], + ], + ], +) +@patch("starpoint.embedding.EmbeddingClient.embed_items") +@patch("starpoint.embedding.requests") +def test_embedding_embed_and_join_metadata_by_columns_mismatch_list( + requests_mock: MagicMock, + embed_items_mock: MagicMock, + mock_embedding_client: embedding.EmbeddingClient, + input_text: List, + input_metadata: List, monkeypatch: MonkeyPatch, +): + input_model = embedding.EmbeddingModel.MINILM + + logger_mock = MagicMock() + monkeypatch.setattr(embedding, "LOGGER", logger_mock) + + actual_json = mock_embedding_client.embed_and_join_metadata_by_columns( + input_text, input_metadata, input_model + ) + + logger_mock.warning.assert_called_once_with( + embedding.TEXT_METADATA_LENGTH_MISMATCH_WARNING + ) + + +@patch("starpoint.embedding.requests") +def test_embedding_embed_items( + requests_mock: MagicMock, + mock_embedding_client: embedding.EmbeddingClient, ): requests_mock.post().ok = True test_value = {"mock_return": "value"} From 87092ef120bbac64bd6286d49881eaccfbb91879 Mon Sep 17 00:00:00 2001 From: Zapple Date: Fri, 15 Sep 2023 23:08:26 -0500 Subject: [PATCH 09/12] Add object-based embedding method --- python/pyproject.toml | 1 + python/starpoint/embedding.py | 48 ++++++++++++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 6d76adc..15540bc 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "requests~=2.31", "validators~=0.21", "pandas~=2.0", + "more-itertools~=10.1", ] [project.optional-dependencies] diff --git a/python/starpoint/embedding.py b/python/starpoint/embedding.py index f5fe5a1..b4055c8 100644 --- a/python/starpoint/embedding.py +++ b/python/starpoint/embedding.py @@ -1,6 +1,7 @@ import logging +import more_itertools from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Hashable, List, Optional from uuid import UUID import requests @@ -104,6 +105,50 @@ def embed_and_join_metadata_by_columns( ] return self.embed_items(text_embedding_items=text_embedding_items, model=model) + def embed_and_join_metadata( + self, + text_embedding_items: List[Dict], + embedding_key: Hashable, + model: EmbeddingModel, + ) -> Dict[str, List[Dict]]: + """Takes some texts and creates embeddings using a model in starpoint. Prefer using `embed_items` + instead, as mismatched `texts` and `metadatas` will output undesirable results. + Under the hood this is using `embed_items`. + + Args: + text_embedding_items: List of dicts of data to create embeddings from. + embedding_key: the key in the embedding items to use to generate the embeddings against + model: An enum choice from EmbeddingModel. + + Returns: + dict: Result with list of texts, metadata, and embeddings. + + Raises: + requests.exceptions.SSLError: Failure likely due to network issues. + """ + texts = list(map(lambda item: item.get(embedding_key), text_embedding_items)) + if not texts: + raise ValueError( + "text_embedding_items received an empty list of list of empty items." + ) + elif not all(texts): + unqualified_indices = list( + more_itertools.locate(texts, lambda x: x is None) + ) + raise ValueError( + "The following indices had items that did not have the " + f"{embedding_key}:\n {unqualified_indices}" + ) + + # TODO: Figure out if we should do a deep copy here instead of editing the original dict + metadatas = list( + map(lambda item: item.pop(embedding_key), text_embedding_items) + ) + + return self.embed_and_join_metadata_by_columns( + texts=text_embedding_items, metadatas=metadatas, model=model + ) + def embed_items( self, text_embedding_items: List[Dict], @@ -122,6 +167,7 @@ def embed_items( Raises: requests.exceptions.SSLError: Failure likely due to network issues. """ + request_data = dict(items=text_embedding_items, model=model.value) try: response = requests.post( From c1a6085bc80cbdec1ca2af91564f3c664d2911aa Mon Sep 17 00:00:00 2001 From: Zapple Date: Fri, 15 Sep 2023 23:14:09 -0500 Subject: [PATCH 10/12] Add some notes --- python/starpoint/embedding.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/starpoint/embedding.py b/python/starpoint/embedding.py index b4055c8..bd63f28 100644 --- a/python/starpoint/embedding.py +++ b/python/starpoint/embedding.py @@ -141,6 +141,8 @@ def embed_and_join_metadata( ) # TODO: Figure out if we should do a deep copy here instead of editing the original dict + # We can also do this operation in the first map, but that might make additional operations we might consider + # doing in here a lot more annoying. Feels like an optimization that shouldn't happen right now. metadatas = list( map(lambda item: item.pop(embedding_key), text_embedding_items) ) From 81a86da318d2ec1b97a7543da1036fd0ad0162a7 Mon Sep 17 00:00:00 2001 From: Zapple Date: Fri, 15 Sep 2023 23:40:33 -0500 Subject: [PATCH 11/12] Clean up docstrings --- python/starpoint/embedding.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/python/starpoint/embedding.py b/python/starpoint/embedding.py index bd63f28..b14e51d 100644 --- a/python/starpoint/embedding.py +++ b/python/starpoint/embedding.py @@ -72,8 +72,8 @@ def embed_and_join_metadata_by_columns( metadatas: List[Dict], model: EmbeddingModel, ) -> Dict[str, List[Dict]]: - """Takes some texts and creates embeddings using a model in starpoint. Prefer using `embed_items` - instead, as mismatched `texts` and `metadatas` will output undesirable results. + """Takes some texts and creates embeddings using a model in starpoint. Prefer using `embed_and_join_metadata` or + `embed_items` instead, as mismatched `texts` and `metadatas` will output undesirable results. Under the hood this is using `embed_items`. Args: @@ -111,13 +111,14 @@ def embed_and_join_metadata( embedding_key: Hashable, model: EmbeddingModel, ) -> Dict[str, List[Dict]]: - """Takes some texts and creates embeddings using a model in starpoint. Prefer using `embed_items` - instead, as mismatched `texts` and `metadatas` will output undesirable results. - Under the hood this is using `embed_items`. + """Takes some texts and creates embeddings using a model in starpoint, and joins them to + all additional data as metadata. Under the hood this is using `embed_and_join_metadata_by_columns` + which is using `embed_items`. Args: text_embedding_items: List of dicts of data to create embeddings from. - embedding_key: the key in the embedding items to use to generate the embeddings against + embedding_key: the key in each item used to create embeddings from. + e.g. `"context"` would be passed if each item looks like this: `{"context": "embed this text"}` model: An enum choice from EmbeddingModel. Returns: @@ -140,9 +141,8 @@ def embed_and_join_metadata( f"{embedding_key}:\n {unqualified_indices}" ) - # TODO: Figure out if we should do a deep copy here instead of editing the original dict - # We can also do this operation in the first map, but that might make additional operations we might consider - # doing in here a lot more annoying. Feels like an optimization that shouldn't happen right now. + # We can also do this operation in the first map that creates texts, but that might make additional operations + # in here a lot more annoying. It's an optimization that shouldn't happen right now. metadatas = list( map(lambda item: item.pop(embedding_key), text_embedding_items) ) From 4a01bf4e668a8cc3c3ecbd693e1f16fce03b6142 Mon Sep 17 00:00:00 2001 From: Zapple Date: Sat, 16 Sep 2023 00:13:14 -0500 Subject: [PATCH 12/12] Fix bugs and add tests --- python/starpoint/embedding.py | 15 +++----- python/tests/test_embedding.py | 68 ++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 9 deletions(-) diff --git a/python/starpoint/embedding.py b/python/starpoint/embedding.py index b14e51d..83b0582 100644 --- a/python/starpoint/embedding.py +++ b/python/starpoint/embedding.py @@ -127,12 +127,11 @@ def embed_and_join_metadata( Raises: requests.exceptions.SSLError: Failure likely due to network issues. """ + if not text_embedding_items: + raise ValueError("text_embedding_items received an empty list.") + texts = list(map(lambda item: item.get(embedding_key), text_embedding_items)) - if not texts: - raise ValueError( - "text_embedding_items received an empty list of list of empty items." - ) - elif not all(texts): + if not all(texts): unqualified_indices = list( more_itertools.locate(texts, lambda x: x is None) ) @@ -143,12 +142,10 @@ def embed_and_join_metadata( # We can also do this operation in the first map that creates texts, but that might make additional operations # in here a lot more annoying. It's an optimization that shouldn't happen right now. - metadatas = list( - map(lambda item: item.pop(embedding_key), text_embedding_items) - ) + list(map(lambda item: item.pop(embedding_key), text_embedding_items)) return self.embed_and_join_metadata_by_columns( - texts=text_embedding_items, metadatas=metadatas, model=model + texts=texts, metadatas=text_embedding_items, model=model ) def embed_items( diff --git a/python/tests/test_embedding.py b/python/tests/test_embedding.py index 8062be8..9e7887b 100644 --- a/python/tests/test_embedding.py +++ b/python/tests/test_embedding.py @@ -189,6 +189,74 @@ def test_embedding_embed_and_join_metadata_by_columns_mismatch_list( ) +@patch("starpoint.embedding.EmbeddingClient.embed_and_join_metadata_by_columns") +@patch("starpoint.embedding.requests") +def test_embedding_embed_and_join_metadata( + requests_mock: MagicMock, + embed_and_join_metadata_by_columns_mock: MagicMock, + mock_embedding_client: embedding.EmbeddingClient, +): + input_text = "embed text" + input_metadata = { + "metadata1": "metadata1", + "metadata2": "metadata2", + } + embed_key = "text" + input_dict = {embed_key: input_text} + input_dict.update(input_metadata) + test_embedding_items = [input_dict] + + expected_item = [{"text": input_text, "metadata": input_metadata}] + input_model = embedding.EmbeddingModel.MINILM + + actual_json = mock_embedding_client.embed_and_join_metadata( + test_embedding_items, embed_key, input_model + ) + + embed_and_join_metadata_by_columns_mock.assert_called_once_with( + texts=[input_text], metadatas=[input_metadata], model=input_model + ) + + +@patch("starpoint.embedding.EmbeddingClient.embed_and_join_metadata_by_columns") +@patch("starpoint.embedding.requests") +def test_embedding_embed_and_join_metadata_no_embed_key( + requests_mock: MagicMock, + embed_and_join_metadata_by_columns_mock: MagicMock, + mock_embedding_client: embedding.EmbeddingClient, +): + input_text = "embed text" + input_metadata = { + "metadata1": "metadata1", + "metadata2": "metadata2", + } + input_dict = {"text": input_text} + input_dict.update(input_metadata) + test_embedding_items = [input_dict] + + input_model = embedding.EmbeddingModel.MINILM + embed_key = "no key" + + with pytest.raises(ValueError): + actual_json = mock_embedding_client.embed_and_join_metadata( + test_embedding_items, embed_key, input_model + ) + + +@patch("starpoint.embedding.EmbeddingClient.embed_and_join_metadata_by_columns") +@patch("starpoint.embedding.requests") +def test_embedding_embed_and_join_metadata_no_values( + requests_mock: MagicMock, + embed_and_join_metadata_by_columns_mock: MagicMock, + mock_embedding_client: embedding.EmbeddingClient, +): + input_model = embedding.EmbeddingModel.MINILM + with pytest.raises(ValueError): + actual_json = mock_embedding_client.embed_and_join_metadata( + [], "text", input_model + ) + + @patch("starpoint.embedding.requests") def test_embedding_embed_items( requests_mock: MagicMock,