diff --git a/align_data/embeddings/embedding_utils.py b/align_data/embeddings/embedding_utils.py
index ff38632..54acf26 100644
--- a/align_data/embeddings/embedding_utils.py
+++ b/align_data/embeddings/embedding_utils.py
@@ -1,5 +1,5 @@
 import logging
-from typing import List, Tuple, Dict, Any, Optional
+from typing import List, Tuple, Dict, Any, Optional, Callable
 from functools import wraps
 
 import openai
@@ -93,11 +93,31 @@ def _single_batch_moderation_check(batch: List[str]) -> List[ModerationInfoType]
     return openai.Moderation.create(input=batch)["results"]
 
 
-def moderation_check(texts: List[str], max_texts_num: int = 32) -> List[ModerationInfoType]:
-    """Batch moderation checks on list of texts."""
+def moderation_check(texts: List[str], max_batch_size: int = 4096, tokens_counter: Callable[[str], int] = len) -> List[ModerationInfoType]:
+    """Batch moderation checks on list of texts.
+
+    :param List[str] texts: the texts to be checked
+    :param int max_batch_size: the max size in tokens for a single batch
+    :param Callable[[str], int] tokens_counter: the function used to count tokens
+    """
+    # Split `texts` into consecutive batches so that the total sum of tokens in
+    # each batch does not exceed `max_batch_size`
+    parts = []
+    part = []
+    part_count = 0
+    for item in texts:
+        if part_count + tokens_counter(item) > max_batch_size:
+            parts.append(part)
+            part = []
+            part_count = 0
+        part.append(item)
+        part_count += tokens_counter(item)
+    if part:
+        parts.append(part)
+
     return [
         result
-        for batch in (texts[i : i + max_texts_num] for i in range(0, len(texts), max_texts_num))
+        for batch in parts
         for result in _single_batch_moderation_check(batch)
     ]
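
For reference, a minimal usage sketch of the new signature (not part of the patch). It assumes tiktoken is available and uses the cl100k_base encoding as the token counter; the default tokens_counter=len counts characters rather than tokens, so callers that care about the token limit should pass their own counter:

    import tiktoken

    from align_data.embeddings.embedding_utils import moderation_check

    _encoding = tiktoken.get_encoding("cl100k_base")

    def count_tokens(text: str) -> int:
        # Number of tokens `text` occupies under the cl100k_base encoding.
        return len(_encoding.encode(text))

    results = moderation_check(
        texts=["first document to check", "second document"],
        max_batch_size=4096,
        tokens_counter=count_tokens,
    )
    flagged = [r["flagged"] for r in results]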