Fix moderation batching #191
Changes from all commits
@@ -1,5 +1,5 @@
 import logging
-from typing import List, Tuple, Dict, Any, Optional
+from typing import List, Tuple, Dict, Any, Optional, Callable
 from functools import wraps

 import openai
@@ -93,11 +93,31 @@ def _single_batch_moderation_check(batch: List[str]) -> List[ModerationInfoType]
     return openai.Moderation.create(input=batch)["results"]


-def moderation_check(texts: List[str], max_texts_num: int = 32) -> List[ModerationInfoType]:
-    """Batch moderation checks on list of texts."""
+def moderation_check(texts: List[str], max_batch_size: int = 4096, tokens_counter: Callable[[str], int] = len) -> List[ModerationInfoType]:
Review comment: this uses
"""Batch moderation checks on list of texts. | ||
Review comment: I would prefer to document the part that explains "what is a moderation check", not the part that explains that code with
+
+    :param List[str] texts: the texts to be checked
+    :param int max_batch_size: the max size in tokens for a single batch
+    :param Callable[[str], int] tokens_counter: the function used to count tokens
+    """
+    # A very ugly loop that will split the `texts` into smaller batches so that the
+    # total sum of tokens in each batch will not exceed `max_batch_size`
Review thread:
* as the comment says. This is very ugly. I couldn't think of anything better that wouldn't be a lot more complicated
* no idea what part of this loop is ugly, but we have https://github.com/StampyAI/stampy/blob/853e28b2e002f50d5861583cf09254093fd4e397/utilities/utilities.py#L307 in stampy bot if you want some inspiration (but this one looks better TBH) :D
* yup, both are terrible :P I wanted something that would do it in, say, around 2-3 lines. This offends my sensibilities. But it's better than any alternatives :/
+    parts = []
+    part = []
+    part_count = 0
+    for item in texts:
+        if part_count + tokens_counter(item) > max_batch_size:
+            parts.append(part)
+            part = []
+            part_count = 0
+        part.append(item)
+        part_count += tokens_counter(item)
+    if part:
+        parts.append(part)
+
     return [
         result
-        for batch in (texts[i : i + max_texts_num] for i in range(0, len(texts), max_texts_num))
+        for batch in parts
         for result in _single_batch_moderation_check(batch)
     ]
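The default `tokens_counter=len` counts characters, not model tokens, so a caller that cares about the real token budget has to pass its own counter. A minimal usage sketch, assuming `tiktoken` is installed and that `cl100k_base` is an acceptable stand-in for the moderation model's tokenizer (both are assumptions, not part of this PR):

```python
import tiktoken

# Caller-side setup: count tokens with a real tokenizer instead of the default len().
encoding = tiktoken.get_encoding("cl100k_base")  # encoding choice is an assumption

def count_tokens(text: str) -> int:
    return len(encoding.encode(text))

results = moderation_check(
    ["first text to screen", "second text to screen"],
    max_batch_size=4096,
    tokens_counter=count_tokens,
)
# Each entry in `results` is one ModerationInfoType dict from the API,
# e.g. results[0]["flagged"] tells you whether the first text was flagged.
```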
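One of the review comments above wished for a splitting step closer to 2-3 lines. A sketch of one way to get there, not part of this PR and with an invented helper name: pull the greedy split into a small generator so that `moderation_check` itself stays a short comprehension.

```python
from typing import Callable, Iterator, List

def batch_by_tokens(
    texts: List[str],
    max_batch_size: int,
    tokens_counter: Callable[[str], int] = len,
) -> Iterator[List[str]]:
    """Yield consecutive chunks of `texts` whose summed token counts stay within `max_batch_size`."""
    part: List[str] = []
    count = 0
    for item in texts:
        # the `part and` guard avoids yielding an empty first batch when a single text exceeds the limit
        if part and count + tokens_counter(item) > max_batch_size:
            yield part
            part, count = [], 0
        part.append(item)
        count += tokens_counter(item)
    if part:
        yield part

# moderation_check would then reduce to roughly:
#     return [
#         result
#         for batch in batch_by_tokens(texts, max_batch_size, tokens_counter)
#         for result in _single_batch_moderation_check(batch)
#     ]
```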
Review comment: I'm starting to reconsider the wisdom against wild imports; for the specific case of `from typing import *`, it might not be as evil as in the general case 🤔
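For what it's worth, `typing` defines `__all__`, so a wildcard import only pulls in the module's public names. A tiny illustration of what that would look like here (not what the PR does, which keeps the explicit import list):

```python
from typing import *  # typing's __all__ limits this to public names such as List, Dict, Optional, Callable

def first_or_none(items: List[str]) -> Optional[str]:
    # Same annotations as with explicit imports, without maintaining the import list by hand.
    return items[0] if items else None
```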