-
Notifications
You must be signed in to change notification settings - Fork 239
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add deepseek r1 client for integration #327
Merged
liyin2015
merged 10 commits into
SylphAI-Inc:main
from
phi-jkim:integrate-deepseek-r1-client
Jan 29, 2025
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
59250b6
Add integration for deepseek-r1
phi-jkim 539d687
Merge pull request #306 from SylphAI-Inc/li
Sylph-AI b0bcbcf
Add deepseek-r1 integration using openai client
phi-jkim 4f5c6d6
Merge remote-tracking branch 'upstream/main' into integrate-deepseek-…
phi-jkim d3ed67c
Reset
phi-jkim 23ac1a0
merge
phi-jkim 56551cc
reset
phi-jkim 006e5d2
Remove benchmarks and use_cases from commit
phi-jkim ee8456c
Modify __init__.py for model_client
phi-jkim dda8086
Remove deepseek from lazy_import
phi-jkim File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
72 changes: 72 additions & 0 deletions
72
adalflow/adalflow/components/model_client/deepseek_client.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import os | ||
import logging | ||
import backoff | ||
from typing import ( | ||
Dict, | ||
Sequence, | ||
Optional, | ||
List, | ||
Any, | ||
TypeVar, | ||
Callable, | ||
Literal, | ||
) | ||
|
||
from adalflow.utils.lazy_import import safe_import, OptionalPackages | ||
from adalflow.components.model_client.openai_client import OpenAIClient | ||
from openai.types import Completion | ||
|
||
openai = safe_import(OptionalPackages.OPENAI.value[0], OptionalPackages.OPENAI.value[1]) | ||
|
||
class DeepSeekClient(OpenAIClient):
    """A component wrapper for the DeepSeek API client.

    DeepSeek's API is compatible with OpenAI's API, making it possible to use OpenAI SDKs
    or OpenAI-compatible software with DeepSeek by adjusting the API base URL.

    This client extends `OpenAIClient` but modifies the default `base_url` and API-key
    environment variable, and defaults `input_type` to `"messages"` so it works with
    DeepSeek reasoner models.

    Documentation: https://api-docs.deepseek.com/guides/reasoning_model

    Args:
        api_key (Optional[str], optional): DeepSeek API key. Defaults to `None`, in which
            case the key is read from the environment variable named by `env_api_key_name`.
        chat_completion_parser (Optional[Callable[[Completion], Any]], optional): A function
            to parse API responses. Defaults to `None` (the parent class's default parser is used).
        input_type (Literal["text", "messages"], optional): Defines how input is handled.
            Defaults to `"messages"` for DeepSeek-reasoner compatibility.
        base_url (str, optional): API base URL, defaults to `"https://api.deepseek.com/v1/"`.
        env_api_key_name (str, optional): Environment variable holding the API key.
            Defaults to `"DEEPSEEK_API_KEY"`.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        # Optional[...] because None is a valid value meaning "use the parent's default parser".
        chat_completion_parser: Optional[Callable[[Completion], Any]] = None,
        input_type: Literal["text", "messages"] = "messages",
        base_url: str = "https://api.deepseek.com/v1/",
        env_api_key_name: str = "DEEPSEEK_API_KEY",
    ):
        """Initializes DeepSeek API client with the correct base URL. The input_type is set
        to "messages" by default to be compatible with DeepSeek reasoner.
        """
        super().__init__(
            api_key=api_key,
            chat_completion_parser=chat_completion_parser,
            input_type=input_type,
            base_url=base_url,
            env_api_key_name=env_api_key_name,
        )
|
||
# Example usage:
if __name__ == "__main__":
    from adalflow.core import Generator
    from adalflow.utils import setup_env, get_logger

    # Enable verbose logging so request/response details are visible while experimenting.
    log = get_logger(level="DEBUG")

    # Load API keys (e.g. DEEPSEEK_API_KEY) from the environment before building the client.
    setup_env()

    prompt_kwargs = {"input_str": "What is the meaning of life?"}

    gen = Generator(
        model_client=DeepSeekClient(),
        model_kwargs={"model": "deepseek-reasoner", "stream": True},
    )

    gen_response = gen(prompt_kwargs)
    print(f"gen_response: {gen_response}")

    # Streaming is enabled, so iterate the chunks of the response.
    for genout in gen_response.data:
        print(f"genout: {genout}")
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -64,6 +64,23 @@ def get_first_message_content(completion: ChatCompletion) -> str: | |
# def _get_chat_completion_usage(completion: ChatCompletion) -> OpenAICompletionUsage: | ||
# return completion.usage | ||
|
||
# A simple heuristic to estimate the number of tokens in a streaming response.
def estimate_token_count(text: str) -> int:
    """Estimate the token count of a given text.

    Uses whitespace splitting as a cheap heuristic: each whitespace-separated
    word counts as one token.

    Args:
        text (str): The text to estimate token count for.

    Returns:
        int: Estimated token count.
    """
    # str.split() with no separator splits on any run of whitespace, so an
    # empty or all-whitespace string yields zero tokens.
    return len(text.split())
|
||
|
||
def parse_stream_response(completion: ChatCompletionChunk) -> str: | ||
r"""Parse the response of the stream API.""" | ||
|
@@ -101,72 +118,83 @@ def get_probabilities(completion: ChatCompletion) -> List[List[TokenLogProb]]: | |
class OpenAIClient(ModelClient): | ||
__doc__ = r"""A component wrapper for the OpenAI API client. | ||
|
||
Support both embedding and chat completion API, including multimodal capabilities. | ||
Supports both embedding and chat completion APIs, including multimodal capabilities. | ||
|
||
Users (1) simplify use ``Embedder`` and ``Generator`` components by passing OpenAIClient() as the model_client. | ||
(2) can use this as an example to create their own API client or extend this class(copying and modifing the code) in their own project. | ||
Users can: | ||
1. Simplify use of ``Embedder`` and ``Generator`` components by passing `OpenAIClient()` as the `model_client`. | ||
2. Use this as a reference to create their own API client or extend this class by copying and modifying the code. | ||
|
||
Note: | ||
We suggest users not to use `response_format` to enforce output data type or `tools` and `tool_choice` in your model_kwargs when calling the API. | ||
We do not know how OpenAI is doing the formating or what prompt they have added. | ||
Instead | ||
- use :ref:`OutputParser<components-output_parsers>` for response parsing and formating. | ||
We recommend avoiding `response_format` to enforce output data type or `tools` and `tool_choice` in `model_kwargs` when calling the API. | ||
OpenAI's internal formatting and added prompts are unknown. Instead: | ||
- Use :ref:`OutputParser<components-output_parsers>` for response parsing and formatting. | ||
|
||
For multimodal inputs, provide images in model_kwargs["images"] as a path, URL, or list of them. | ||
The model must support vision capabilities (e.g., gpt-4o, gpt-4o-mini, o1, o1-mini). | ||
For multimodal inputs, provide images in `model_kwargs["images"]` as a path, URL, or list of them. | ||
The model must support vision capabilities (e.g., `gpt-4o`, `gpt-4o-mini`, `o1`, `o1-mini`). | ||
|
||
For image generation, use model_type=ModelType.IMAGE_GENERATION and provide: | ||
- model: "dall-e-3" or "dall-e-2" | ||
For image generation, use `model_type=ModelType.IMAGE_GENERATION` and provide: | ||
- model: `"dall-e-3"` or `"dall-e-2"` | ||
- prompt: Text description of the image to generate | ||
- size: "1024x1024", "1024x1792", or "1792x1024" for DALL-E 3; "256x256", "512x512", or "1024x1024" for DALL-E 2 | ||
- quality: "standard" or "hd" (DALL-E 3 only) | ||
- size: `"1024x1024"`, `"1024x1792"`, or `"1792x1024"` for DALL-E 3; `"256x256"`, `"512x512"`, or `"1024x1024"` for DALL-E 2 | ||
- quality: `"standard"` or `"hd"` (DALL-E 3 only) | ||
- n: Number of images to generate (1 for DALL-E 3, 1-10 for DALL-E 2) | ||
- response_format: "url" or "b64_json" | ||
- response_format: `"url"` or `"b64_json"` | ||
|
||
Args: | ||
api_key (Optional[str], optional): OpenAI API key. Defaults to None. | ||
chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None. | ||
Default is `get_first_message_content`. | ||
|
||
api_key (Optional[str], optional): OpenAI API key. Defaults to `None`. | ||
chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion into a `str`. Defaults to `None`. | ||
The default parser is `get_first_message_content`. | ||
base_url (str): The API base URL to use when initializing the client. | ||
Defaults to `"https://api.openai.com"`, but can be customized for third-party API providers or self-hosted models. | ||
env_api_key_name (str): The environment variable name for the API key. Defaults to `"OPENAI_API_KEY"`. | ||
|
||
References: | ||
- Embeddings models: https://platform.openai.com/docs/guides/embeddings | ||
- Chat models: https://platform.openai.com/docs/guides/text-generation | ||
- Vision models: https://platform.openai.com/docs/guides/vision | ||
- Image models: https://platform.openai.com/docs/guides/images | ||
- OpenAI docs: https://platform.openai.com/docs/introduction | ||
- OpenAI API Overview: https://platform.openai.com/docs/introduction | ||
- Embeddings Guide: https://platform.openai.com/docs/guides/embeddings | ||
- Chat Completion Models: https://platform.openai.com/docs/guides/text-generation | ||
- Vision Models: https://platform.openai.com/docs/guides/vision | ||
- Image Generation: https://platform.openai.com/docs/guides/images | ||
""" | ||
|
||
|
||
def __init__( | ||
self, | ||
api_key: Optional[str] = None, | ||
chat_completion_parser: Callable[[Completion], Any] = None, | ||
input_type: Literal["text", "messages"] = "text", | ||
base_url: str = "https://api.openai.com/v1/", | ||
env_api_key_name: str = "OPENAI_API_KEY" | ||
): | ||
r"""It is recommended to set the OPENAI_API_KEY environment variable instead of passing it as an argument. | ||
|
||
Args: | ||
api_key (Optional[str], optional): OpenAI API key. Defaults to None. | ||
base_url (str): The API base URL to use when initializing the client. | ||
env_api_key_name (str): The environment variable name for the API key. Defaults to `"OPENAI_API_KEY"`. | ||
""" | ||
super().__init__() | ||
self._api_key = api_key | ||
self._env_api_key_name = env_api_key_name | ||
self.base_url = base_url | ||
self.sync_client = self.init_sync_client() | ||
self.async_client = None # only initialize if the async call is called | ||
self.chat_completion_parser = ( | ||
chat_completion_parser or get_first_message_content | ||
) | ||
self._input_type = input_type | ||
self._api_kwargs = {} # add api kwargs when the OpenAI Client is called | ||
|
||
def init_sync_client(self): | ||
api_key = self._api_key or os.getenv("OPENAI_API_KEY") | ||
api_key = self._api_key or os.getenv(self._env_api_key_name) | ||
if not api_key: | ||
raise ValueError("Environment variable OPENAI_API_KEY must be set") | ||
return OpenAI(api_key=api_key) | ||
raise ValueError(f"Environment variable {self._env_api_key_name} must be set") | ||
return OpenAI(api_key=api_key, base_url=self.base_url) | ||
|
||
def init_async_client(self): | ||
api_key = self._api_key or os.getenv("OPENAI_API_KEY") | ||
api_key = self._api_key or os.getenv(self._env_api_key_name) | ||
if not api_key: | ||
raise ValueError("Environment variable OPENAI_API_KEY must be set") | ||
return AsyncOpenAI(api_key=api_key) | ||
raise ValueError(f"Environment variable {self._env_api_key_name} must be set") | ||
return AsyncOpenAI(api_key=api_key, base_url=self.base_url) | ||
|
||
# def _parse_chat_completion(self, completion: ChatCompletion) -> "GeneratorOutput": | ||
# # TODO: raw output it is better to save the whole completion as a source of truth instead of just the message | ||
|
@@ -208,9 +236,7 @@ def track_completion_usage( | |
) | ||
return usage | ||
else: | ||
raise NotImplementedError( | ||
"streaming completion usage tracking is not implemented" | ||
) | ||
raise ValueError(f"Unsupported completion type: {type(completion)}") | ||
|
||
def parse_embedding_response( | ||
self, response: CreateEmbeddingResponse | ||
|
@@ -268,11 +294,19 @@ def convert_inputs_to_api_kwargs( | |
system_end_tag = "<END_OF_SYSTEM_PROMPT>" | ||
user_start_tag = "<START_OF_USER_PROMPT>" | ||
user_end_tag = "<END_OF_USER_PROMPT>" | ||
pattern = f"{system_start_tag}(.*?){system_end_tag}{user_start_tag}(.*?){user_end_tag}" | ||
|
||
# new regex pattern to ignore special characters such as \n | ||
pattern = ( | ||
Reviewer comment: The original regex expression was not able to parse correctly. I fixed the regex pattern using \s to ignore whitespace/escape characters. — Maintainer reply: pretty good!
||
rf"{system_start_tag}\s*(.*?)\s*{system_end_tag}\s*" | ||
rf"{user_start_tag}\s*(.*?)\s*{user_end_tag}" | ||
) | ||
|
||
# Compile the regular expression | ||
regex = re.compile(pattern) | ||
|
||
# re.DOTALL is to allow . to match newline so that (.*?) does not match in a single line | ||
regex = re.compile(pattern, re.DOTALL) | ||
# Match the pattern | ||
match = regex.search(input) | ||
match = regex.match(input) | ||
system_prompt, input_str = None, None | ||
|
||
if match: | ||
|
@@ -328,6 +362,9 @@ def convert_inputs_to_api_kwargs( | |
final_model_kwargs["mask"] = self._encode_image(mask) | ||
else: | ||
raise ValueError(f"model_type {model_type} is not supported") | ||
|
||
print(f"final_model_kwargs: {final_model_kwargs}") | ||
|
||
return final_model_kwargs | ||
|
||
def parse_image_generation_response(self, response: List[Image]) -> GeneratorOutput: | ||
|
@@ -362,6 +399,7 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINE | |
kwargs is the combined input and model_kwargs. Support streaming call. | ||
""" | ||
log.info(f"api_kwargs: {api_kwargs}") | ||
self._api_kwargs = api_kwargs | ||
if model_type == ModelType.EMBEDDER: | ||
return self.sync_client.embeddings.create(**api_kwargs) | ||
elif model_type == ModelType.LLM: | ||
|
@@ -403,6 +441,8 @@ async def acall( | |
""" | ||
kwargs is the combined input and model_kwargs | ||
""" | ||
# store the api kwargs in the client | ||
self._api_kwargs = api_kwargs | ||
if self.async_client is None: | ||
self.async_client = self.init_async_client() | ||
if model_type == ModelType.EMBEDDER: | ||
|
@@ -498,21 +538,21 @@ def _prepare_image_content( | |
|
||
|
||
# Example usage: | ||
# if __name__ == "__main__": | ||
# from adalflow.core import Generator | ||
# from adalflow.utils import setup_env, get_logger | ||
# | ||
# log = get_logger(level="DEBUG") | ||
# | ||
# setup_env() | ||
# prompt_kwargs = {"input_str": "What is the meaning of life?"} | ||
# | ||
# gen = Generator( | ||
# model_client=OpenAIClient(), | ||
# model_kwargs={"model": "gpt-3.5-turbo", "stream": True}, | ||
# ) | ||
# gen_response = gen(prompt_kwargs) | ||
# print(f"gen_response: {gen_response}") | ||
# | ||
# for genout in gen_response.data: | ||
# print(f"genout: {genout}") | ||
if __name__ == "__main__": | ||
from adalflow.core import Generator | ||
from adalflow.utils import setup_env, get_logger | ||
|
||
log = get_logger(level="DEBUG") | ||
|
||
setup_env() | ||
prompt_kwargs = {"input_str": "What is the meaning of life?"} | ||
|
||
gen = Generator( | ||
model_client=OpenAIClient(), | ||
model_kwargs={"model": "gpt-3.5-turbo", "stream": True}, | ||
) | ||
gen_response = gen(prompt_kwargs) | ||
print(f"gen_response: {gen_response}") | ||
|
||
for genout in gen_response.data: | ||
print(f"genout: {genout}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import unittest | ||
from unittest.mock import patch, Mock | ||
import os | ||
|
||
from openai import Stream | ||
from openai.types import CompletionUsage | ||
from openai.types.chat import ChatCompletion, ChatCompletionChunk | ||
from adalflow.core.types import ModelType, GeneratorOutput | ||
from adalflow.components.model_client.deepseek_client import DeepSeekClient | ||
from unittest.mock import AsyncMock | ||
|
||
def getenv_side_effect(key):
    """Mimic os.getenv: return the fake DeepSeek key for DEEPSEEK_API_KEY, else None."""
    if key == "DEEPSEEK_API_KEY":
        return "fake_api_key"
    return None
|
||
class TestDeepSeekClient(unittest.TestCase):
    """Unit tests for DeepSeekClient configuration defaults and client initialization."""

    def setUp(self):
        # Pass the key explicitly so no real environment variable is required.
        self.client = DeepSeekClient(api_key="fake_api_key")

    def test_deepseek_init(self):
        """The client should default to DeepSeek's base URL, message input, and env var name."""
        self.assertEqual(self.client.base_url, "https://api.deepseek.com/v1/")
        self.assertEqual(self.client._input_type, "messages")
        self.assertEqual(self.client._env_api_key_name, "DEEPSEEK_API_KEY")

    # mock os.getenv(self._env_api_key_name) so the key resolves without a real env var
    @patch("os.getenv")
    def test_deepseek_init_sync_client(self, mock_os_getenv):
        mock_os_getenv.return_value = "fake_api_key"
        # Fix: assign the returned client. The original discarded the return value,
        # so the assertions ran against the client created in setUp and the
        # os.getenv patch never mattered. This now mirrors the async test below.
        self.client.sync_client = self.client.init_sync_client()
        self.assertEqual(self.client.sync_client.api_key, "fake_api_key")
        self.assertEqual(
            self.client.sync_client.base_url, "https://api.deepseek.com/v1/"
        )

    @patch("os.getenv")
    def test_deepseek_init_async_client(self, mock_os_getenv):
        mock_os_getenv.return_value = "fake_api_key"
        self.client.async_client = self.client.init_async_client()
        self.assertEqual(self.client.async_client.api_key, "fake_api_key")
        self.assertEqual(
            self.client.async_client.base_url, "https://api.deepseek.com/v1/"
        )


if __name__ == "__main__":
    unittest.main()
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wait for the next version rather than rewriting everything, since they only use the openai package.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are instead inheriting the openai client after making some modifications!