Commit
Merge pull request #176 from StampyAI/fix-large-batch-embedding
fix 2048+ batch embedding
henri123lemoine authored Sep 1, 2023
2 parents fd4a587 + 68e32b8 commit 4782ec3
Showing 1 changed file with 27 additions and 23 deletions.
50 changes: 27 additions & 23 deletions align_data/embeddings/embedding_utils.py
@@ -88,34 +88,39 @@ def wrapper(*args, **kwargs):
 
 
 @handle_openai_errors
-def moderation_check(texts: List[str], max_texts_num: int = 32) -> List[ModerationInfoType]:
-    """
-    Check moderation on a list of texts.
-
-    Parameters:
-    - texts (List[str]): List of texts to be checked for moderation.
-    - max_texts_num (int): Number of texts to check at once. Defaults to 32.
-
-    Returns:
-    - List[ModerationInfoType]: List of moderation results for the provided texts.
-    """
-    total_texts = len(texts)
-    results = []
-
-    for i in range(0, total_texts, max_texts_num):
-        batch_texts = texts[i : i + max_texts_num]
-        batch_results = openai.Moderation.create(input=batch_texts)["results"]
-        results.extend(batch_results)
-
-    return results
+def _single_batch_moderation_check(batch: List[str]) -> List[ModerationInfoType]:
+    """Process a batch for moderation checks."""
+    return openai.Moderation.create(input=batch)["results"]
+
+
+def moderation_check(texts: List[str], max_texts_num: int = 32) -> List[ModerationInfoType]:
+    """Batch moderation checks on list of texts."""
+    return [
+        result
+        for batch in (texts[i : i + max_texts_num] for i in range(0, len(texts), max_texts_num))
+        for result in _single_batch_moderation_check(batch)
+    ]
 
 
 @handle_openai_errors
-def _compute_openai_embeddings(non_flagged_texts: List[str], **kwargs) -> List[List[float]]:
-    data = openai.Embedding.create(
-        input=non_flagged_texts, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs
-    ).data
-    return [d["embedding"] for d in data]
+def _single_batch_compute_openai_embeddings(batch: List[str], **kwargs) -> List[List[float]]:
+    """Compute embeddings for a batch."""
+    batch_data = openai.Embedding.create(input=batch, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs).data
+    return [d["embedding"] for d in batch_data]
+
+
+def _compute_openai_embeddings(
+    non_flagged_texts: List[str], max_texts_num: int = 2048, **kwargs
+) -> List[List[float]]:
+    """Batch computation of embeddings for non-flagged texts."""
+    return [
+        embedding
+        for batch in (
+            non_flagged_texts[i : i + max_texts_num]
+            for i in range(0, len(non_flagged_texts), max_texts_num)
+        )
+        for embedding in _single_batch_compute_openai_embeddings(batch, **kwargs)
+    ]
 
 
 def get_embeddings_without_moderation(
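
Both new helpers follow the same chunk-and-flatten pattern: a generator expression slices the input list into pieces of at most max_texts_num items, and the outer comprehension flattens the per-batch results back into one list, preserving input order. A minimal standalone sketch of that pattern, with a hypothetical stub (stub_batch_api, not part of the commit) standing in for the per-batch OpenAI call:

from typing import List

def stub_batch_api(batch: List[str]) -> List[str]:
    # Stand-in for a per-batch API call such as openai.Moderation.create.
    return [text.upper() for text in batch]

def process_in_batches(texts: List[str], max_texts_num: int = 2048) -> List[str]:
    # Slice `texts` into batches of at most `max_texts_num` items, call the
    # API once per batch, and flatten the results into a single list.
    return [
        result
        for batch in (texts[i : i + max_texts_num] for i in range(0, len(texts), max_texts_num))
        for result in stub_batch_api(batch)
    ]

assert process_in_batches(["a", "b", "c"], max_texts_num=2) == ["A", "B", "C"]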
@@ -193,7 +198,6 @@ def get_embeddings(
     Returns:
     - Tuple[List[Optional[List[float]]], ModerationInfoListType]: Tuple containing the list of embeddings (with None for flagged texts) and the moderation results.
     """
-    assert len(texts) <= 2048, "The batch size should not be larger than 2048."
     assert all(texts), "No empty strings allowed in the input list."
 
     # replace newlines, which can negatively affect performance
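
Because _compute_openai_embeddings now splits its input into batches of at most 2048 texts before calling the API (2048 being the per-request input limit of the OpenAI embeddings endpoint), the hard cap in get_embeddings can be dropped. A quick check of the batch sizes the new slicing produces for an input the old assert would have rejected:

# 5000 texts are now sent as three requests instead of failing the assert.
sizes = [len(range(i, min(i + 2048, 5000))) for i in range(0, 5000, 2048)]
assert sizes == [2048, 2048, 904]
assert sum(sizes) == 5000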
