def _split_into_batches(items: List[str], batch_size: int) -> List[List[str]]:
    """
    Split a list into consecutive slices of at most `batch_size` elements.

    Parameters:
    - items (List[str]): The list to split. Order is preserved.
    - batch_size (int): Maximum number of elements per slice. Must be positive.

    Returns:
    - List[List[str]]: The slices, in order. Empty if `items` is empty.

    Raises:
    - ValueError: If `batch_size` is not positive.
    """
    # A negative step would make range() empty and silently drop every item,
    # and a zero step fails with an unhelpful message, so reject both up front.
    if batch_size <= 0:
        raise ValueError(f"batch size must be a positive integer, got {batch_size}")
    return [items[start : start + batch_size] for start in range(0, len(items), batch_size)]


@handle_openai_errors
def _single_batch_moderation_check(batch: List[str]) -> List[ModerationInfoType]:
    """
    Run the moderation endpoint on a single batch of texts.

    Parameters:
    - batch (List[str]): Texts sent in one `openai.Moderation.create` call.

    Returns:
    - List[ModerationInfoType]: The "results" entry of the moderation response.
    """
    # NOTE(review): decorated per batch, so whatever handle_openai_errors does
    # on failure presumably applies to one batch rather than the whole input.
    return openai.Moderation.create(input=batch)["results"]


def moderation_check(texts: List[str], max_texts_num: int = 32) -> List[ModerationInfoType]:
    """
    Check moderation on a list of texts, sending them in batches.

    Parameters:
    - texts (List[str]): List of texts to be checked for moderation.
    - max_texts_num (int): Number of texts to check at once. Defaults to 32.

    Returns:
    - List[ModerationInfoType]: Moderation results for all batches, concatenated
      in batch order.

    Raises:
    - ValueError: If `max_texts_num` is not positive.
    """
    results: List[ModerationInfoType] = []
    for batch in _split_into_batches(texts, max_texts_num):
        results.extend(_single_batch_moderation_check(batch))
    return results


@handle_openai_errors
def _single_batch_compute_openai_embeddings(batch: List[str], **kwargs) -> List[List[float]]:
    """
    Compute embeddings for a single batch of texts.

    Parameters:
    - batch (List[str]): Texts sent in one `openai.Embedding.create` call.
    - **kwargs: Extra keyword arguments forwarded to `openai.Embedding.create`.

    Returns:
    - List[List[float]]: The "embedding" entry of each item in the response data.
    """
    batch_data = openai.Embedding.create(input=batch, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs).data
    return [d["embedding"] for d in batch_data]


def _compute_openai_embeddings(
    non_flagged_texts: List[str], max_texts_num: int = 2048, **kwargs
) -> List[List[float]]:
    """
    Compute embeddings for non-flagged texts, sending them in batches.

    Parameters:
    - non_flagged_texts (List[str]): Texts to embed.
    - max_texts_num (int): Number of texts to embed per request. Defaults to 2048.
    - **kwargs: Extra keyword arguments forwarded to each embedding request.

    Returns:
    - List[List[float]]: Embeddings for all batches, concatenated in batch order.

    Raises:
    - ValueError: If `max_texts_num` is not positive.
    """
    embeddings: List[List[float]] = []
    for batch in _split_into_batches(non_flagged_texts, max_texts_num):
        embeddings.extend(_single_batch_compute_openai_embeddings(batch, **kwargs))
    return embeddings


# NOTE: the same patch also touches get_embeddings (not fully visible here):
# it removes the guard
#     assert len(texts) <= 2048, "The batch size should not be larger than 2048."
# and leaves the `assert all(texts)` empty-string check in place.