Skip to content

Commit

Permalink
Add OpenAI v1.x.x compatibility to embedding models and api base
Browse files Browse the repository at this point in the history
- Support passing the client as the parameter
- In OpenAI v1.x.x, `api_base` has been renamed into `base_url`
openai/openai-python#1051 (comment)

Signed-off-by: Hollow Man <[email protected]>
  • Loading branch information
HollowMan6 committed Sep 2, 2024
1 parent 4547976 commit 12d228b
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
5 changes: 4 additions & 1 deletion gptcache/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,10 @@ def set_azure_openai_key():

openai.api_type = "azure"
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.api_base = os.getenv("OPENAI_API_BASE")
if hasattr(openai, "api_base"):
openai.api_base = os.getenv("OPENAI_API_BASE")
elif hasattr(openai, "base_url"):
openai.base_url = os.getenv("OPENAI_BASE_URL", os.getenv("OPENAI_API_BASE"))
openai.api_version = os.getenv("OPENAI_API_VERSION")

cache = Cache()
4 changes: 2 additions & 2 deletions gptcache/embedding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def Cohere(model="large", api_key=None):
return cohere.Cohere(model, api_key)


def OpenAI(model="text-embedding-ada-002", api_key=None):
return openai.OpenAI(model, api_key)
def OpenAI(model="text-embedding-ada-002", api_key=None, api_base=None, client=None):
return openai.OpenAI(model, api_key, api_base, client)


def Huggingface(model="distilbert-base-uncased"):
Expand Down
17 changes: 12 additions & 5 deletions gptcache/embedding/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,23 @@ class OpenAI(BaseEmbedding):
embed = encoder.to_embeddings(test_sentence)
"""

def __init__(self, model: str = "text-embedding-ada-002", api_key: str = None, api_base: str = None):
def __init__(self, model: str = "text-embedding-ada-002", api_key: str = None, api_base: str = None, client = None):
if not api_key:
if openai.api_key:
api_key = openai.api_key
else:
api_key = os.getenv("OPENAI_API_KEY")
if not api_base:
if openai.api_base:
if hasattr(openai, "api_base") and openai.api_base:
api_base = openai.api_base
elif hasattr(openai, "base_url") and openai.base_url:
api_base = openai.base_url
else:
api_base = os.getenv("OPENAI_API_BASE")
api_base = os.getenv("OPENAI_API_BASE", os.getenv("OPENAI_BASE_URL"))
openai.api_key = api_key
self.api_base = api_base # don't override all of openai as we may just want to override for say embeddings
self.model = model
self.client = client
if model in self.dim_dict():
self.__dimension = self.dim_dict()[model]
else:
Expand All @@ -56,8 +59,12 @@ def to_embeddings(self, data, **_):
:return: a text embedding in shape of (dim,).
"""
sentence_embeddings = openai.Embedding.create(model=self.model, input=data, api_base=self.api_base)
return np.array(sentence_embeddings["data"][0]["embedding"]).astype("float32")
if self.client:
sentence_embeddings = self.client.embeddings.create(model=self.model, input=data)
return np.array(sentence_embeddings.data[0].embedding).astype("float32")
else:
sentence_embeddings = openai.Embedding.create(model=self.model, input=data, api_base=self.api_base)
return np.array(sentence_embeddings["data"][0]["embedding"]).astype("float32")

@property
def dimension(self):
Expand Down

0 comments on commit 12d228b

Please sign in to comment.