Skip to content

Commit

Permalink
fix openai version
Browse files Browse the repository at this point in the history
  • Loading branch information
mruwnik committed Nov 28, 2023
1 parent 5f90564 commit 3316648
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
9 changes: 6 additions & 3 deletions align_data/embeddings/embedding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import List, Tuple, Dict, Any, Optional, Callable
from functools import wraps

import openai
from openai import OpenAI

from langchain.embeddings import HuggingFaceEmbeddings
from openai import (
OpenAIError,
Expand All @@ -21,11 +22,13 @@
from align_data.settings import (
USE_OPENAI_EMBEDDINGS,
OPENAI_EMBEDDINGS_MODEL,
OPENAI_ORGANIZATION,
EMBEDDING_LENGTH_BIAS,
SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL,
DEVICE,
)

client = OpenAI(organization=OPENAI_ORGANIZATION)

# --------------------
# CONSTANTS & CONFIGURATION
Expand Down Expand Up @@ -90,7 +93,7 @@ def wrapper(*args, **kwargs):
@handle_openai_errors
def _single_batch_moderation_check(batch: List[str]) -> List[ModerationInfoType]:
"""Process a batch for moderation checks."""
return openai.Moderation.create(input=batch)["results"]
return client.moderations.create(input=batch)["results"]


def moderation_check(texts: List[str], max_batch_size: int = 4096, tokens_counter: Callable[[str], int] = len) -> List[ModerationInfoType]:
Expand Down Expand Up @@ -125,7 +128,7 @@ def moderation_check(texts: List[str], max_batch_size: int = 4096, tokens_counte
@handle_openai_errors
def _single_batch_compute_openai_embeddings(batch: List[str], **kwargs) -> List[List[float]]:
"""Compute embeddings for a batch."""
batch_data = openai.Embedding.create(input=batch, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs).data
batch_data = client.embeddings.create(input=batch, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs).data
return [d["embedding"] for d in batch_data]


Expand Down
3 changes: 1 addition & 2 deletions align_data/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@
OPENAI_EMBEDDINGS_MODEL = "text-embedding-ada-002"
OPENAI_EMBEDDINGS_DIMS = 1536
OPENAI_EMBEDDINGS_RATE_LIMIT = 3500
openai.api_key = os.environ.get("OPENAI_API_KEY", None)
openai.organization = os.environ.get("OPENAI_ORGANIZATION", None)
OPENAI_ORGANIZATION = os.environ.get("OPENAI_ORGANIZATION", None)

SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = "sentence-transformers/multi-qa-mpnet-base-cos-v1"
SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS = 768
Expand Down

0 comments on commit 3316648

Please sign in to comment.