Skip to content

Commit

Permalink
Add object-based embedding method
Browse files Browse the repository at this point in the history
  • Loading branch information
FullMetalMeowchemist committed Sep 16, 2023
1 parent 3785010 commit 87092ef
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
1 change: 1 addition & 0 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies = [
"requests~=2.31",
"validators~=0.21",
"pandas~=2.0",
"more-itertools~=10.1",
]

[project.optional-dependencies]
Expand Down
48 changes: 47 additions & 1 deletion python/starpoint/embedding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import more_itertools
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Hashable, List, Optional
from uuid import UUID

import requests
Expand Down Expand Up @@ -104,6 +105,50 @@ def embed_and_join_metadata_by_columns(
]
return self.embed_items(text_embedding_items=text_embedding_items, model=model)

def embed_and_join_metadata(
self,
text_embedding_items: List[Dict],
embedding_key: Hashable,
model: EmbeddingModel,
) -> Dict[str, List[Dict]]:
"""Takes some texts and creates embeddings using a model in starpoint. Prefer using `embed_items`
instead, as mismatched `texts` and `metadatas` will output undesirable results.
Under the hood this is using `embed_items`.
Args:
text_embedding_items: List of dicts of data to create embeddings from.
embedding_key: the key in the embedding items to use to generate the embeddings against
model: An enum choice from EmbeddingModel.
Returns:
dict: Result with list of texts, metadata, and embeddings.
Raises:
requests.exceptions.SSLError: Failure likely due to network issues.
"""
texts = list(map(lambda item: item.get(embedding_key), text_embedding_items))
if not texts:
raise ValueError(
"text_embedding_items received an empty list of list of empty items."
)
elif not all(texts):
unqualified_indices = list(
more_itertools.locate(texts, lambda x: x is None)
)
raise ValueError(
"The following indices had items that did not have the "
f"{embedding_key}:\n {unqualified_indices}"
)

# TODO: Figure out if we should do a deep copy here instead of editing the original dict
metadatas = list(
map(lambda item: item.pop(embedding_key), text_embedding_items)
)

return self.embed_and_join_metadata_by_columns(
texts=text_embedding_items, metadatas=metadatas, model=model
)

def embed_items(
self,
text_embedding_items: List[Dict],
Expand All @@ -122,6 +167,7 @@ def embed_items(
Raises:
requests.exceptions.SSLError: Failure likely due to network issues.
"""

request_data = dict(items=text_embedding_items, model=model.value)
try:
response = requests.post(
Expand Down

0 comments on commit 87092ef

Please sign in to comment.