-
Notifications
You must be signed in to change notification settings - Fork 901
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #175 from Riddhimaan-Senapati/main
Added support for deepseek LLMs
- Loading branch information
Showing
4 changed files
with
127 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import openai | ||
import os | ||
from aisuite.provider import Provider, LLMError | ||
|
||
|
||
class DeepseekProvider(Provider):
    """Provider adapter that routes chat-completion requests to the DeepSeek API.

    DeepSeek exposes an OpenAI-compatible HTTP API, so this provider reuses the
    standard ``openai`` client with its base URL pointed at DeepSeek.
    """

    def __init__(self, **config):
        """
        Initialize the DeepSeek provider with the given configuration.

        The entire configuration dictionary is passed to the OpenAI client
        constructor. ``api_key`` may be supplied in *config*; otherwise it is
        read from the ``DEEPSEEK_API_KEY`` environment variable.

        Raises:
            ValueError: If no API key is found in the config or environment.
        """
        # Ensure API key is provided either in config or via environment variable.
        config.setdefault("api_key", os.getenv("DEEPSEEK_API_KEY"))
        if not config["api_key"]:
            # Fixed: the message previously told users to set OPENAI_API_KEY,
            # but the code above reads DEEPSEEK_API_KEY.
            raise ValueError(
                "DeepSeek API key is missing. Please provide it in the config "
                "or set the DEEPSEEK_API_KEY environment variable."
            )
        # Always target DeepSeek's OpenAI-compatible endpoint.
        config["base_url"] = "https://api.deepseek.com"

        # NOTE: We could choose to remove the api_key handling above since the
        # OpenAI client automatically infers certain values from environment
        # variables (e.g. OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID) —
        # except the base URL, which must be the DeepSeek URL.

        # Pass the entire config to the OpenAI client constructor.
        self.client = openai.OpenAI(**config)

    def chat_completions_create(self, model, messages, **kwargs):
        """Create a chat completion using the DeepSeek API.

        Args:
            model: Model identifier, e.g. ``"deepseek-chat"``.
            messages: List of chat messages in OpenAI format
                (dicts with ``role`` and ``content`` keys).
            **kwargs: Additional arguments forwarded to the underlying API call
                (e.g. ``temperature``).

        Returns:
            The raw response object from the OpenAI client.
        """
        # Any exception raised by the OpenAI client propagates to the caller.
        # TODO(review): consider catching these and raising a custom LLMError.
        return self.client.chat.completions.create(
            model=model,
            messages=messages,
            **kwargs,  # Pass any additional arguments to the OpenAI API.
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# DeepSeek | ||
|
||
To use DeepSeek with `aisuite`, you’ll need a [DeepSeek account](https://platform.deepseek.com). After logging in, go to the [API Keys](https://platform.deepseek.com/api_keys) section in your account settings and generate a new key. Once you have your key, add it to your environment as follows:
|
||
```shell | ||
export DEEPSEEK_API_KEY="your-deepseek-api-key" | ||
``` | ||
|
||
## Create a Chat Completion | ||
|
||
(Note: DeepSeek's API is compatible with OpenAI's, which is why the `openai` client library is required — DeepSeek does not currently provide a Python library of its own.)
|
||
Install the `openai` Python client: | ||
|
||
Example with pip: | ||
```shell | ||
pip install openai | ||
``` | ||
|
||
Example with poetry: | ||
```shell | ||
poetry add openai | ||
``` | ||
|
||
In your code: | ||
```python | ||
import aisuite as ai | ||
client = ai.Client() | ||
|
||
provider = "deepseek" | ||
model_id = "deepseek-chat" | ||
|
||
messages = [ | ||
{"role": "system", "content": "You are a helpful assistant."}, | ||
{"role": "user", "content": "What’s the weather like in San Francisco?"}, | ||
] | ||
|
||
response = client.chat.completions.create( | ||
model=f"{provider}:{model_id}", | ||
messages=messages, | ||
) | ||
|
||
print(response.choices[0].message.content) | ||
``` | ||
|
||
Happy coding! If you’d like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from unittest.mock import MagicMock, patch | ||
|
||
import pytest | ||
|
||
from aisuite.providers.deepseek_provider import DeepseekProvider | ||
|
||
|
||
@pytest.fixture(autouse=True)
def set_api_key_env_var(monkeypatch):
    """Fixture to set environment variables for tests.

    Runs automatically for every test in this module (autouse=True) so that
    DeepseekProvider() can be constructed without a real API key;
    monkeypatch restores the environment after each test.
    """
    monkeypatch.setenv("DEEPSEEK_API_KEY", "test-api-key")
|
||
|
||
def test_deepseek_provider():
    """High-level test that the provider is initialized and chat completions are requested successfully.

    Renamed from ``test_groq_provider`` — the original name was a copy-paste
    leftover from the Groq provider tests; this test exercises DeepseekProvider.
    """
    user_greeting = "Hello!"
    message_history = [{"role": "user", "content": user_greeting}]
    selected_model = "our-favorite-model"
    chosen_temperature = 0.75
    response_text_content = "mocked-text-response-from-model"

    # The autouse fixture has set DEEPSEEK_API_KEY, so construction succeeds.
    provider = DeepseekProvider()

    # Build a mock response shaped like the OpenAI chat-completion object.
    mock_response = MagicMock()
    mock_response.choices = [MagicMock()]
    mock_response.choices[0].message = MagicMock()
    mock_response.choices[0].message.content = response_text_content

    # Patch the underlying client call so no network request is made.
    with patch.object(
        provider.client.chat.completions,
        "create",
        return_value=mock_response,
    ) as mock_create:
        response = provider.chat_completions_create(
            messages=message_history,
            model=selected_model,
            temperature=chosen_temperature,
        )

        # Verify kwargs were forwarded unchanged to the OpenAI client.
        mock_create.assert_called_with(
            messages=message_history,
            model=selected_model,
            temperature=chosen_temperature,
        )

        assert response.choices[0].message.content == response_text_content