Skip to content

Commit

Permalink
update to newer library
Browse files Browse the repository at this point in the history
  • Loading branch information
mruwnik committed Jun 4, 2024
1 parent ffb5cdf commit b0e2dba
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
10 changes: 5 additions & 5 deletions align_data/embeddings/embedding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def wrapper(*args, **kwargs):
@handle_openai_errors
def _single_batch_moderation_check(batch: List[str]) -> List[ModerationInfoType]:
    """Run one OpenAI moderation request for a single batch of texts.

    Args:
        batch: The texts to check; the whole batch is sent in one API call.

    Returns:
        The per-text moderation results for the batch.

    Note:
        Uses the openai>=1.0 client, whose responses are typed objects —
        the results are read via the ``.results`` attribute, not by dict
        subscripting as in the pre-1.0 client.
    """
    return client.moderations.create(input=batch).results


def moderation_check(texts: List[str], max_batch_size: int = 4096, tokens_counter: Callable[[str], int] = len) -> List[ModerationInfoType]:
Expand Down Expand Up @@ -129,8 +129,8 @@ def moderation_check(texts: List[str], max_batch_size: int = 4096, tokens_counte
@handle_openai_errors
def _single_batch_compute_openai_embeddings(batch: List[str], **kwargs) -> List[List[float]]:
    """Compute OpenAI embeddings for a single batch of texts.

    Args:
        batch: The texts to embed; sent to the API in one request.
        **kwargs: Extra keyword arguments forwarded to
            ``client.embeddings.create`` (e.g. ``dimensions``).

    Returns:
        One embedding vector (list of floats) per input text, in order.

    Note:
        Uses the openai>=1.0 client: the model is selected with the
        ``model=`` keyword (the pre-1.0 ``engine=`` keyword was removed),
        and each item in ``.data`` is a typed object whose vector is the
        ``.embedding`` attribute rather than a dict entry.
    """
    batch_data = client.embeddings.create(input=batch, model=OPENAI_EMBEDDINGS_MODEL, **kwargs).data
    return [d.embedding for d in batch_data]


def _compute_openai_embeddings(
Expand Down Expand Up @@ -199,7 +199,7 @@ def get_embeddings_or_none_if_flagged(
- Tuple[Optional[List[List[float]]], ModerationInfoListType]: Tuple containing the list of embeddings (or None if any text is flagged) and the moderation results.
"""
moderation_results = moderation_check(texts)
if any(result["flagged"] for result in moderation_results):
if any(result.flagged for result in moderation_results):
return None, moderation_results

embeddings = get_embeddings_without_moderation(texts, source, **kwargs)
Expand Down Expand Up @@ -229,7 +229,7 @@ def get_embeddings(

# Check all texts for moderation flags
moderation_results = moderation_check(texts)
flags = [result["flagged"] for result in moderation_results]
flags = [result.flagged for result in moderation_results]

non_flagged_texts = [text for text, flag in zip(texts, flags) if not flag]
non_flagged_embeddings = get_embeddings_without_moderation(non_flagged_texts, source, **kwargs)
Expand Down
14 changes: 7 additions & 7 deletions align_data/embeddings/pinecone/update_pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,11 @@ def _make_pinecone_entry(self, article: Article) -> PineconeEntry | None:
text_chunks = get_text_chunks(article, self.text_splitter)
embeddings, moderation_results = get_embeddings(text_chunks, article.source)

if any(result["flagged"] for result in moderation_results):
if any(result.flagged for result in moderation_results):
flagged_text_chunks = [
f'Chunk {i}: "{text}"'
for i, (text, result) in enumerate(zip(text_chunks, moderation_results))
if result["flagged"]
if result.flagged
]
logger.warning(
f"OpenAI moderation flagged text chunks for the following article: {article.id}"
Expand All @@ -137,12 +137,12 @@ def _make_pinecone_entry(self, article: Article) -> PineconeEntry | None:
confidence=article.confidence,
)
except (
ValueError,
TypeError,
AttributeError,
# ValueError,
# TypeError,
# AttributeError,
ValidationError,
MissingFieldsError,
MissingEmbeddingModelError,
# MissingFieldsError,
# MissingEmbeddingModelError,
) as e:
logger.warning(e)
article.append_comment(f"Error encountered while processing this article: {e}")
Expand Down

0 comments on commit b0e2dba

Please sign in to comment.