diff --git a/llms/providers/anthropic.py b/llms/providers/anthropic.py index b52107c..a05c007 100644 --- a/llms/providers/anthropic.py +++ b/llms/providers/anthropic.py @@ -1,7 +1,6 @@ # llms/providers/anthropic.py -import os -from typing import AsyncGenerator, Dict, Generator, List, Optional +from typing import AsyncGenerator, Dict, Generator, List, Optional, Union import anthropic @@ -28,18 +27,26 @@ class AnthropicProvider(BaseProvider): "completion": 5.51, "token_limit": 100_000, }, - "claude-2": {"prompt": 11.02, "completion": 32.68, "token_limit": 100_000}, + "claude-2": {"prompt": 8.00, "completion": 24.00, "token_limit": 200_000}, } - def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None): + def __init__( + self, + api_key: Union[str, None] = None, + model: Union[str, None] = None, + client_kwargs: Union[dict, None] = None, + async_client_kwargs: Union[dict, None] = None, + ): if model is None: model = list(self.MODEL_INFO.keys())[0] self.model = model - if api_key is None: - api_key = os.getenv("ANTHROPIC_API_KEY") - self.client = anthropic.Anthropic(api_key=api_key) - self.async_client = anthropic.AsyncAnthropic(api_key=api_key) + if client_kwargs is None: + client_kwargs = {} + self.client = anthropic.Anthropic(api_key=api_key, **client_kwargs) + if async_client_kwargs is None: + async_client_kwargs = {} + self.async_client = anthropic.AsyncAnthropic(api_key=api_key, **async_client_kwargs) def count_tokens(self, content: str) -> int: return self.client.count_tokens(content) diff --git a/llms/providers/openai.py b/llms/providers/openai.py index 3ba3125..4ea98b8 100644 --- a/llms/providers/openai.py +++ b/llms/providers/openai.py @@ -17,12 +17,22 @@ class OpenAIProvider(BaseProvider): "gpt-4-1106-preview": {"prompt": 10.0, "completion": 20.0, "token_limit": 128000, "is_chat": True}, } - def __init__(self, api_key=None, model=None): + def __init__( + self, + api_key: Union[str, None] = None, + model: Union[str, None] = None, + client_kwargs: Union[dict, None] = None, + async_client_kwargs: Union[dict, None] = None, + ): if model is None: model = list(self.MODEL_INFO.keys())[0] self.model = model - self.client = OpenAI(api_key=api_key) - self.async_client = AsyncOpenAI(api_key=api_key) + if client_kwargs is None: + client_kwargs = {} + self.client = OpenAI(api_key=api_key, **client_kwargs) + if async_client_kwargs is None: + async_client_kwargs = {} + self.async_client = AsyncOpenAI(api_key=api_key, **async_client_kwargs) @property def is_chat_model(self) -> bool: