Skip to content

Commit

Permalink
Merge pull request #20 from kagisearch/openai_migration
Browse files Browse the repository at this point in the history
Add kwargs to client init
  • Loading branch information
bkiat1123 authored Nov 22, 2023
2 parents a113fdd + 98c6dd7 commit 712bfaf
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
23 changes: 15 additions & 8 deletions llms/providers/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# llms/providers/anthropic.py

import os
from typing import AsyncGenerator, Dict, Generator, List, Optional
from typing import AsyncGenerator, Dict, Generator, List, Optional, Union

import anthropic

Expand All @@ -28,18 +27,26 @@ class AnthropicProvider(BaseProvider):
"completion": 5.51,
"token_limit": 100_000,
},
"claude-2": {"prompt": 11.02, "completion": 32.68, "token_limit": 100_000},
"claude-2": {"prompt": 8.00, "completion": 24.00, "token_limit": 200_000},
}

def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
def __init__(
self,
api_key: Union[str, None] = None,
model: Union[str, None] = None,
client_kwargs: Union[dict, None] = None,
async_client_kwargs: Union[dict, None] = None,
):
if model is None:
model = list(self.MODEL_INFO.keys())[0]
self.model = model

if api_key is None:
api_key = os.getenv("ANTHROPIC_API_KEY")
self.client = anthropic.Anthropic(api_key=api_key)
self.async_client = anthropic.AsyncAnthropic(api_key=api_key)
if client_kwargs is None:
client_kwargs = {}
self.client = anthropic.Anthropic(api_key=api_key, **client_kwargs)
if async_client_kwargs is None:
async_client_kwargs = {}
self.async_client = anthropic.AsyncAnthropic(api_key=api_key, **async_client_kwargs)

def count_tokens(self, content: str) -> int:
return self.client.count_tokens(content)
Expand Down
16 changes: 13 additions & 3 deletions llms/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,22 @@ class OpenAIProvider(BaseProvider):
"gpt-4-1106-preview": {"prompt": 10.0, "completion": 20.0, "token_limit": 128000, "is_chat": True},
}

def __init__(self, api_key=None, model=None):
def __init__(
self,
api_key: Union[str, None] = None,
model: Union[str, None] = None,
client_kwargs: Union[dict, None] = None,
async_client_kwargs: Union[dict, None] = None,
):
if model is None:
model = list(self.MODEL_INFO.keys())[0]
self.model = model
self.client = OpenAI(api_key=api_key)
self.async_client = AsyncOpenAI(api_key=api_key)
if client_kwargs is None:
client_kwargs = {}
self.client = OpenAI(api_key=api_key, **client_kwargs)
if async_client_kwargs is None:
async_client_kwargs = {}
self.async_client = AsyncOpenAI(api_key=api_key, **async_client_kwargs)

@property
def is_chat_model(self) -> bool:
Expand Down

0 comments on commit 712bfaf

Please sign in to comment.