Skip to content

Commit

Permalink
fix openai version (#194)
Browse files Browse the repository at this point in the history
  • Loading branch information
mruwnik authored Nov 28, 2023
1 parent 5f90564 commit 119d428
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
1 change: 1 addition & 0 deletions .github/workflows/fetch-dataset.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ jobs:
CODA_TOKEN: ${{ secrets.CODA_TOKEN || inputs.coda_token }}
AIRTABLE_API_KEY: ${{ secrets.AIRTABLE_API_KEY || inputs.airtable_api_key }}
YOUTUBE_API_KEY: ${{ secrets.YOUTUBE_API_KEY || inputs.youtube_api_key }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY || inputs.openai_api_key }}
ARD_DB_USER: ${{ secrets.ARD_DB_USER || inputs.db_user }}
ARD_DB_PASSWORD: ${{ secrets.ARD_DB_PASSWORD || inputs.db_password }}
ARD_DB_HOST: ${{ secrets.ARD_DB_HOST || inputs.db_host }}
Expand Down
10 changes: 7 additions & 3 deletions align_data/embeddings/embedding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import List, Tuple, Dict, Any, Optional, Callable
from functools import wraps

import openai
from openai import OpenAI

from langchain.embeddings import HuggingFaceEmbeddings
from openai import (
OpenAIError,
Expand All @@ -21,11 +22,14 @@
from align_data.settings import (
USE_OPENAI_EMBEDDINGS,
OPENAI_EMBEDDINGS_MODEL,
OPENAI_API_KEY,
OPENAI_ORGANIZATION,
EMBEDDING_LENGTH_BIAS,
SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL,
DEVICE,
)

client = OpenAI(api_key=OPENAI_API_KEY, organization=OPENAI_ORGANIZATION)

# --------------------
# CONSTANTS & CONFIGURATION
Expand Down Expand Up @@ -90,7 +94,7 @@ def wrapper(*args, **kwargs):
@handle_openai_errors
def _single_batch_moderation_check(batch: List[str]) -> List[ModerationInfoType]:
"""Process a batch for moderation checks."""
return openai.Moderation.create(input=batch)["results"]
return client.moderations.create(input=batch)["results"]


def moderation_check(texts: List[str], max_batch_size: int = 4096, tokens_counter: Callable[[str], int] = len) -> List[ModerationInfoType]:
Expand Down Expand Up @@ -125,7 +129,7 @@ def moderation_check(texts: List[str], max_batch_size: int = 4096, tokens_counte
@handle_openai_errors
def _single_batch_compute_openai_embeddings(batch: List[str], **kwargs) -> List[List[float]]:
"""Compute embeddings for a batch."""
batch_data = openai.Embedding.create(input=batch, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs).data
batch_data = client.embeddings.create(input=batch, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs).data
return [d["embedding"] for d in batch_data]


Expand Down
4 changes: 2 additions & 2 deletions align_data/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@
OPENAI_EMBEDDINGS_MODEL = "text-embedding-ada-002"
OPENAI_EMBEDDINGS_DIMS = 1536
OPENAI_EMBEDDINGS_RATE_LIMIT = 3500
openai.api_key = os.environ.get("OPENAI_API_KEY", None)
openai.organization = os.environ.get("OPENAI_ORGANIZATION", None)
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
OPENAI_ORGANIZATION = os.environ.get("OPENAI_ORGANIZATION", None)

SENTENCE_TRANSFORMER_EMBEDDINGS_MODEL = "sentence-transformers/multi-qa-mpnet-base-cos-v1"
SENTENCE_TRANSFORMER_EMBEDDINGS_DIMS = 768
Expand Down

0 comments on commit 119d428

Please sign in to comment.