diff --git a/align_data/embeddings/embedding_utils.py b/align_data/embeddings/embedding_utils.py
index b0b1d16..0fc4ae2 100644
--- a/align_data/embeddings/embedding_utils.py
+++ b/align_data/embeddings/embedding_utils.py
@@ -94,7 +94,7 @@ def wrapper(*args, **kwargs):
 @handle_openai_errors
 def _single_batch_moderation_check(batch: List[str]) -> List[ModerationInfoType]:
     """Process a batch for moderation checks."""
-    return client.moderations.create(input=batch)["results"]
+    return client.moderations.create(input=batch).results


 def moderation_check(texts: List[str], max_batch_size: int = 4096, tokens_counter: Callable[[str], int] = len) -> List[ModerationInfoType]:
@@ -129,8 +129,8 @@ def moderation_check(texts: List[str], max_batch_size: int = 4096, tokens_counte
 @handle_openai_errors
 def _single_batch_compute_openai_embeddings(batch: List[str], **kwargs) -> List[List[float]]:
     """Compute embeddings for a batch."""
-    batch_data = client.embeddings.create(input=batch, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs).data
-    return [d["embedding"] for d in batch_data]
+    batch_data = client.embeddings.create(input=batch, model=OPENAI_EMBEDDINGS_MODEL, **kwargs).data
+    return [d.embedding for d in batch_data]


 def _compute_openai_embeddings(
@@ -199,7 +199,7 @@ def get_embeddings_or_none_if_flagged(
     - Tuple[Optional[List[List[float]]], ModerationInfoListType]: Tuple containing the list of embeddings (or None if any text is flagged) and the moderation results.
     """
     moderation_results = moderation_check(texts)
-    if any(result["flagged"] for result in moderation_results):
+    if any(result.flagged for result in moderation_results):
         return None, moderation_results

     embeddings = get_embeddings_without_moderation(texts, source, **kwargs)
@@ -229,7 +229,7 @@ def get_embeddings(

     # Check all texts for moderation flags
     moderation_results = moderation_check(texts)
-    flags = [result["flagged"] for result in moderation_results]
+    flags = [result.flagged for result in moderation_results]

     non_flagged_texts = [text for text, flag in zip(texts, flags) if not flag]
     non_flagged_embeddings = get_embeddings_without_moderation(non_flagged_texts, source, **kwargs)
diff --git a/align_data/embeddings/pinecone/update_pinecone.py b/align_data/embeddings/pinecone/update_pinecone.py
index f1a06f3..75f64f6 100644
--- a/align_data/embeddings/pinecone/update_pinecone.py
+++ b/align_data/embeddings/pinecone/update_pinecone.py
@@ -112,11 +112,11 @@ def _make_pinecone_entry(self, article: Article) -> PineconeEntry | None:
         text_chunks = get_text_chunks(article, self.text_splitter)
         embeddings, moderation_results = get_embeddings(text_chunks, article.source)

-        if any(result["flagged"] for result in moderation_results):
+        if any(result.flagged for result in moderation_results):
             flagged_text_chunks = [
                 f'Chunk {i}: "{text}"'
                 for i, (text, result) in enumerate(zip(text_chunks, moderation_results))
-                if result["flagged"]
+                if result.flagged
             ]
             logger.warning(
                 f"OpenAI moderation flagged text chunks for the following article: {article.id}"
@@ -137,12 +137,12 @@ def _make_pinecone_entry(self, article: Article) -> PineconeEntry | None:
                 confidence=article.confidence,
             )
         except (
-            ValueError,
-            TypeError,
-            AttributeError,
+            # ValueError,
+            # TypeError,
+            # AttributeError,
             ValidationError,
-            MissingFieldsError,
-            MissingEmbeddingModelError,
+            # MissingFieldsError,
+            # MissingEmbeddingModelError,
         ) as e:
             logger.warning(e)
             article.append_comment(f"Error encountered while processing this article: {e}")
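
The changes above track the openai>=1.0 Python client, which returns typed response objects with attribute access instead of dicts and renames the `engine=` parameter of `embeddings.create` to `model=`. Below is a minimal standalone sketch of the new access pattern, assuming an `OPENAI_API_KEY` in the environment and using `"text-embedding-ada-002"` as a stand-in for the repo's `OPENAI_EMBEDDINGS_MODEL` constant.

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

texts = ["hello world"]

# Moderation: v1 returns a typed response object, so access .results and
# .flagged as attributes rather than the old dict-style ["results"]/["flagged"].
moderation = client.moderations.create(input=texts)
flags = [result.flagged for result in moderation.results]

# Embeddings: v1 takes model= (engine= is gone), and each item of .data is an
# Embedding object exposing .embedding as a list of floats.
response = client.embeddings.create(input=texts, model="text-embedding-ada-002")
vectors = [item.embedding for item in response.data]
```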