Fix moderation batching #191

Merged · 1 commit · Oct 18, 2023
align_data/embeddings/embedding_utils.py (28 changes: 24 additions & 4 deletions)
@@ -1,5 +1,5 @@
import logging
-from typing import List, Tuple, Dict, Any, Optional
+from typing import List, Tuple, Dict, Any, Optional, Callable


I'm starting to reconsider the wisdom against wildcard imports; for the specific case of `from typing import *` it might not be as evil as in the general case 🤔

from functools import wraps

import openai
@@ -93,11 +93,31 @@ def _single_batch_moderation_check(batch: List[str]) -> List[ModerationInfoType]
    return openai.Moderation.create(input=batch)["results"]


-def moderation_check(texts: List[str], max_texts_num: int = 32) -> List[ModerationInfoType]:
-    """Batch moderation checks on list of texts."""
+def moderation_check(texts: List[str], max_batch_size: int = 4096, tokens_counter: Callable[[str], int] = len) -> List[ModerationInfoType]:
Collaborator Author
This uses `len` to calculate token usage, which will overestimate by something like 2-3 times. It's also used elsewhere here, from what I can see, so I left it for now.
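For reference, a more accurate counter could be plugged in through the new `tokens_counter` parameter. A minimal sketch using tiktoken, assuming the `cl100k_base` encoding is a reasonable proxy for the moderation endpoint's tokenizer (an assumption, not something this PR specifies):

```python
# Hypothetical drop-in replacement for the default `len` tokens_counter.
# Assumes tiktoken is installed; cl100k_base is only an approximation of
# whatever tokenizer the moderation endpoint actually uses.
import tiktoken

_encoding = tiktoken.get_encoding("cl100k_base")

def count_tokens(text: str) -> int:
    # Counts real BPE tokens instead of characters, so batches come out
    # roughly 2-3x fuller before hitting max_batch_size.
    return len(_encoding.encode(text))
```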

"""Batch moderation checks on list of texts.


I would prefer to document the part that explains "what is a moderation check", not the part that explains that the code with a `batch` variable is doing some batching 😅


+    :param List[str] texts: the texts to be checked
+    :param int max_batch_size: the max size in tokens for a single batch
+    :param Callable[[str], int] tokens_counter: the function used to count tokens
+    """
+    # A very ugly loop that will split the `texts` into smaller batches so that the
+    # total sum of tokens in each batch will not exceed `max_batch_size`
Collaborator Author

As the comment says, this is very ugly. I couldn't think of anything better that wouldn't be a lot more complicated.

No idea what part of this loop is ugly, but we have https://github.com/StampyAI/stampy/blob/853e28b2e002f50d5861583cf09254093fd4e397/utilities/utilities.py#L307 in the stampy bot if you want some inspiration (but this one looks better TBH) :D

Collaborator Author

Yup, both are terrible :P I wanted something that would do it in, say, around 2-3 lines. This offends my sensibilities, but it's better than any alternatives :/
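For what it's worth, one way to get it shorter (if not prettier) is the stateful-key `itertools.groupby` trick. A sketch only, not the merged code; `batch_by_tokens` is a hypothetical helper name:

```python
# Sketch of a more compact alternative to the accumulator loop below.
# `batch_by_tokens` is a hypothetical name, not part of this PR.
from itertools import groupby
from typing import Callable, List

def batch_by_tokens(texts: List[str], max_size: int, count: Callable[[str], int] = len) -> List[List[str]]:
    group, total = 0, 0

    def key(text: str) -> int:
        # Bump the group id whenever the running total would exceed max_size.
        nonlocal group, total
        total += count(text)
        if total > max_size:
            group, total = group + 1, count(text)
        return group

    return [list(batch) for _, batch in groupby(texts, key=key)]
```

One incidental difference: this version never emits an empty batch, whereas the loop below appends an empty `part` when the very first text alone exceeds `max_batch_size`.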

+    parts = []
+    part = []
+    part_count = 0
+    for item in texts:
+        if part_count + tokens_counter(item) > max_batch_size:
+            parts.append(part)
+            part = []
+            part_count = 0
+        part.append(item)
+        part_count += tokens_counter(item)
+    if part:
+        parts.append(part)

    return [
        result
-        for batch in (texts[i : i + max_texts_num] for i in range(0, len(texts), max_texts_num))
+        for batch in parts
        for result in _single_batch_moderation_check(batch)
    ]
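To illustrate the new signature, a hypothetical call site (with `count_tokens` being the sketched tiktoken counter from above, not something defined in this PR):

```python
# Hypothetical usage; assumes the pre-1.0 openai client is configured elsewhere.
texts = ["first text to check", "second text", "third one"]
results = moderation_check(texts, max_batch_size=4096, tokens_counter=count_tokens)
flagged = [r for r in results if r["flagged"]]
```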
