diff --git a/python/pyproject.toml b/python/pyproject.toml index 6d76adc..15540bc 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "requests~=2.31", "validators~=0.21", "pandas~=2.0", + "more-itertools~=10.1", ] [project.optional-dependencies] diff --git a/python/starpoint/embedding.py b/python/starpoint/embedding.py index f5fe5a1..b4055c8 100644 --- a/python/starpoint/embedding.py +++ b/python/starpoint/embedding.py @@ -1,6 +1,7 @@ import logging +import more_itertools from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Hashable, List, Optional from uuid import UUID import requests @@ -104,6 +105,50 @@ def embed_and_join_metadata_by_columns( ] return self.embed_items(text_embedding_items=text_embedding_items, model=model) + def embed_and_join_metadata( + self, + text_embedding_items: List[Dict], + embedding_key: Hashable, + model: EmbeddingModel, + ) -> Dict[str, List[Dict]]: + """Takes some texts and creates embeddings using a model in starpoint. Prefer using `embed_items` + instead, as mismatched `texts` and `metadatas` will output undesirable results. + Under the hood this is using `embed_items`. + + Args: + text_embedding_items: List of dicts of data to create embeddings from. + embedding_key: the key in the embedding items to use to generate the embeddings against + model: An enum choice from EmbeddingModel. + + Returns: + dict: Result with list of texts, metadata, and embeddings. + + Raises: + requests.exceptions.SSLError: Failure likely due to network issues. + """ + texts = list(map(lambda item: item.get(embedding_key), text_embedding_items)) + if not texts: + raise ValueError( + "text_embedding_items received an empty list of list of empty items." + ) + elif not all(texts): + unqualified_indices = list( + more_itertools.locate(texts, lambda x: x is None) + ) + raise ValueError( + "The following indices had items that did not have the " + f"{embedding_key}:\n {unqualified_indices}" + ) + + # TODO: Figure out if we should do a deep copy here instead of editing the original dict + metadatas = list( + map(lambda item: item.pop(embedding_key), text_embedding_items) + ) + + return self.embed_and_join_metadata_by_columns( + texts=text_embedding_items, metadatas=metadatas, model=model + ) + def embed_items( self, text_embedding_items: List[Dict], @@ -122,6 +167,7 @@ def embed_items( Raises: requests.exceptions.SSLError: Failure likely due to network issues. """ + request_data = dict(items=text_embedding_items, model=model.value) try: response = requests.post(