Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add deepseek r1 client for integration #327

Merged
merged 10 commits into from
Jan 29, 2025
2 changes: 2 additions & 0 deletions adalflow/adalflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
AnthropicAPIClient,
CohereAPIClient,
BedrockAPIClient,
DeepSeekClient,
)

# data pipeline
Expand Down Expand Up @@ -130,6 +131,7 @@
"OpenAIClient",
"GoogleGenAIClient",
"GroqAPIClient",
"DeepSeekClient",
"OllamaClient",
"TransformersClient",
"AnthropicAPIClient",
Expand Down
6 changes: 6 additions & 0 deletions adalflow/adalflow/components/model_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@
"adalflow.components.model_client.openai_client.OpenAIClient",
OptionalPackages.OPENAI,
)

DeepSeekClient = LazyImport(
"adalflow.components.model_client.deepseek_client.DeepSeekClient",
None
)

GoogleGenAIClient = LazyImport(
"adalflow.components.model_client.google_client.GoogleGenAIClient",
OptionalPackages.GOOGLE_GENERATIVEAI,
Expand Down
72 changes: 72 additions & 0 deletions adalflow/adalflow/components/model_client/deepseek_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import os
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait for the next version — do not rewrite everything, since they only use the openai package.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are instead inheriting the openai client after making some modifications!

import logging
import backoff
from typing import (
Dict,
Sequence,
Optional,
List,
Any,
TypeVar,
Callable,
Literal,
)

from adalflow.utils.lazy_import import safe_import, OptionalPackages
from adalflow.components.model_client.openai_client import OpenAIClient
from openai.types import Completion

openai = safe_import(OptionalPackages.OPENAI.value[0], OptionalPackages.OPENAI.value[1])

class DeepSeekClient(OpenAIClient):
    """Model client for the DeepSeek API.

    DeepSeek exposes an OpenAI-compatible HTTP API, so this client reuses
    :class:`OpenAIClient` wholesale and only redirects it: the base URL points
    at DeepSeek's endpoint and the API key is read from ``DEEPSEEK_API_KEY``.

    Documentation: https://api-docs.deepseek.com/guides/reasoning_model

    Args:
        api_key (Optional[str], optional): DeepSeek API key. Defaults to `None`.
        chat_completion_parser (Callable[[Completion], Any], optional): A function to parse API responses.
        input_type (Literal["text", "messages"], optional): Defines how input is handled. Defaults to `"messages"`
            for compatibility with the DeepSeek reasoner models.
        base_url (str, optional): API base URL, defaults to `"https://api.deepseek.com/v1/"`.
        env_api_key_name (str, optional): Environment variable holding the key. Defaults to `"DEEPSEEK_API_KEY"`.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        chat_completion_parser: Callable[[Completion], Any] = None,
        input_type: Literal["text", "messages"] = "messages",
        base_url: str = "https://api.deepseek.com/v1/",
        env_api_key_name: str = "DEEPSEEK_API_KEY"
    ):
        """Initialize the DeepSeek client by delegating to ``OpenAIClient``
        with DeepSeek-specific defaults (base URL, env var, messages input)."""
        super().__init__(
            api_key=api_key,
            chat_completion_parser=chat_completion_parser,
            input_type=input_type,
            base_url=base_url,
            env_api_key_name=env_api_key_name,
        )

# Example usage:
if __name__ == "__main__":
    # Smoke test against the live DeepSeek API; requires DEEPSEEK_API_KEY
    # to be resolvable by setup_env().
    from adalflow.core import Generator
    from adalflow.utils import setup_env, get_logger

    log = get_logger(level="DEBUG")

    prompt_kwargs = {"input_str": "What is the meaning of life?"}

    setup_env()

    gen = Generator(
        model_client=DeepSeekClient(),
        # stream=True exercises the streaming parse path inherited from OpenAIClient.
        model_kwargs={"model": "deepseek-reasoner", "stream": True},
    )

    gen_response = gen(prompt_kwargs)
    print(f"gen_response: {gen_response}")

    # NOTE(review): assumes gen_response.data is iterable when streaming —
    # TODO confirm against Generator's streaming output contract.
    for genout in gen_response.data:
        print(f"genout: {genout}")

146 changes: 93 additions & 53 deletions adalflow/adalflow/components/model_client/openai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,23 @@ def get_first_message_content(completion: ChatCompletion) -> str:
# def _get_chat_completion_usage(completion: ChatCompletion) -> OpenAICompletionUsage:
# return completion.usage

# A simple heuristic to estimate token count for estimating number of tokens in a Streaming response
def estimate_token_count(text: str) -> int:
    """Roughly estimate how many tokens *text* contains.

    Whitespace-separated words serve as a cheap stand-in for real tokenizer
    output; intended for streaming responses where the API reports no usage.

    Args:
        text (str): The text to estimate token count for.

    Returns:
        int: Estimated token count.
    """
    return len(text.split())


def parse_stream_response(completion: ChatCompletionChunk) -> str:
r"""Parse the response of the stream API."""
Expand Down Expand Up @@ -101,72 +118,83 @@ def get_probabilities(completion: ChatCompletion) -> List[List[TokenLogProb]]:
class OpenAIClient(ModelClient):
__doc__ = r"""A component wrapper for the OpenAI API client.

Support both embedding and chat completion API, including multimodal capabilities.
Supports both embedding and chat completion APIs, including multimodal capabilities.

Users (1) simplify use ``Embedder`` and ``Generator`` components by passing OpenAIClient() as the model_client.
(2) can use this as an example to create their own API client or extend this class(copying and modifing the code) in their own project.
Users can:
1. Simplify use of ``Embedder`` and ``Generator`` components by passing `OpenAIClient()` as the `model_client`.
2. Use this as a reference to create their own API client or extend this class by copying and modifying the code.

Note:
We suggest users not to use `response_format` to enforce output data type or `tools` and `tool_choice` in your model_kwargs when calling the API.
We do not know how OpenAI is doing the formating or what prompt they have added.
Instead
- use :ref:`OutputParser<components-output_parsers>` for response parsing and formating.
We recommend avoiding `response_format` to enforce output data type or `tools` and `tool_choice` in `model_kwargs` when calling the API.
OpenAI's internal formatting and added prompts are unknown. Instead:
- Use :ref:`OutputParser<components-output_parsers>` for response parsing and formatting.

For multimodal inputs, provide images in model_kwargs["images"] as a path, URL, or list of them.
The model must support vision capabilities (e.g., gpt-4o, gpt-4o-mini, o1, o1-mini).
For multimodal inputs, provide images in `model_kwargs["images"]` as a path, URL, or list of them.
The model must support vision capabilities (e.g., `gpt-4o`, `gpt-4o-mini`, `o1`, `o1-mini`).

For image generation, use model_type=ModelType.IMAGE_GENERATION and provide:
- model: "dall-e-3" or "dall-e-2"
For image generation, use `model_type=ModelType.IMAGE_GENERATION` and provide:
- model: `"dall-e-3"` or `"dall-e-2"`
- prompt: Text description of the image to generate
- size: "1024x1024", "1024x1792", or "1792x1024" for DALL-E 3; "256x256", "512x512", or "1024x1024" for DALL-E 2
- quality: "standard" or "hd" (DALL-E 3 only)
- size: `"1024x1024"`, `"1024x1792"`, or `"1792x1024"` for DALL-E 3; `"256x256"`, `"512x512"`, or `"1024x1024"` for DALL-E 2
- quality: `"standard"` or `"hd"` (DALL-E 3 only)
- n: Number of images to generate (1 for DALL-E 3, 1-10 for DALL-E 2)
- response_format: "url" or "b64_json"
- response_format: `"url"` or `"b64_json"`

Args:
api_key (Optional[str], optional): OpenAI API key. Defaults to None.
chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
Default is `get_first_message_content`.

api_key (Optional[str], optional): OpenAI API key. Defaults to `None`.
chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion into a `str`. Defaults to `None`.
The default parser is `get_first_message_content`.
base_url (str): The API base URL to use when initializing the client.
Defaults to `"https://api.openai.com"`, but can be customized for third-party API providers or self-hosted models.
env_api_key_name (str): The environment variable name for the API key. Defaults to `"OPENAI_API_KEY"`.

References:
- Embeddings models: https://platform.openai.com/docs/guides/embeddings
- Chat models: https://platform.openai.com/docs/guides/text-generation
- Vision models: https://platform.openai.com/docs/guides/vision
- Image models: https://platform.openai.com/docs/guides/images
- OpenAI docs: https://platform.openai.com/docs/introduction
- OpenAI API Overview: https://platform.openai.com/docs/introduction
- Embeddings Guide: https://platform.openai.com/docs/guides/embeddings
- Chat Completion Models: https://platform.openai.com/docs/guides/text-generation
- Vision Models: https://platform.openai.com/docs/guides/vision
- Image Generation: https://platform.openai.com/docs/guides/images
"""


def __init__(
    self,
    api_key: Optional[str] = None,
    chat_completion_parser: Optional[Callable[[Completion], Any]] = None,
    input_type: Literal["text", "messages"] = "text",
    base_url: str = "https://api.openai.com/v1/",
    env_api_key_name: str = "OPENAI_API_KEY"
):
    r"""It is recommended to set the OPENAI_API_KEY environment variable instead of passing it as an argument.

    Args:
        api_key (Optional[str], optional): OpenAI API key. Defaults to None, in which
            case the key is read from the environment variable named by ``env_api_key_name``.
        chat_completion_parser (Optional[Callable], optional): Parser applied to chat
            completions; falls back to ``get_first_message_content`` when None.
        input_type (Literal["text", "messages"]): How prompt input is handled downstream.
            Defaults to "text".
        base_url (str): The API base URL to use when initializing the client.
            Customizable for OpenAI-compatible providers or self-hosted models.
        env_api_key_name (str): The environment variable name for the API key. Defaults to `"OPENAI_API_KEY"`.
    """
    super().__init__()
    self._api_key = api_key
    self._env_api_key_name = env_api_key_name
    # base_url must be assigned before init_sync_client(), which reads it.
    self.base_url = base_url
    self.sync_client = self.init_sync_client()
    self.async_client = None  # only initialize if the async call is called
    self.chat_completion_parser = (
        chat_completion_parser or get_first_message_content
    )
    self._input_type = input_type
    self._api_kwargs = {}  # add api kwargs when the OpenAI Client is called

def init_sync_client(self):
    """Create and return a synchronous ``OpenAI`` client.

    Resolves the API key from ``self._api_key`` or, failing that, from the
    environment variable named by ``self._env_api_key_name``, and points the
    client at ``self.base_url`` so OpenAI-compatible providers work.

    Note: the pasted diff had retained both the superseded and the new lines
    (a dead assignment and unreachable statements after the first return);
    this keeps only the intended new behavior.

    Raises:
        ValueError: If no API key can be resolved.
    """
    api_key = self._api_key or os.getenv(self._env_api_key_name)
    if not api_key:
        raise ValueError(f"Environment variable {self._env_api_key_name} must be set")
    return OpenAI(api_key=api_key, base_url=self.base_url)

def init_async_client(self):
    """Create and return an asynchronous ``AsyncOpenAI`` client.

    Mirrors :meth:`init_sync_client`: key from ``self._api_key`` or the
    ``self._env_api_key_name`` environment variable, endpoint from
    ``self.base_url``.

    Note: the pasted diff had retained both the superseded and the new lines
    (unreachable code after the first return); this keeps only the intended
    new behavior.

    Raises:
        ValueError: If no API key can be resolved.
    """
    api_key = self._api_key or os.getenv(self._env_api_key_name)
    if not api_key:
        raise ValueError(f"Environment variable {self._env_api_key_name} must be set")
    return AsyncOpenAI(api_key=api_key, base_url=self.base_url)

# def _parse_chat_completion(self, completion: ChatCompletion) -> "GeneratorOutput":
# # TODO: raw output it is better to save the whole completion as a source of truth instead of just the message
Expand Down Expand Up @@ -208,9 +236,7 @@ def track_completion_usage(
)
return usage
else:
raise NotImplementedError(
"streaming completion usage tracking is not implemented"
)
raise ValueError(f"Unsupported completion type: {type(completion)}")

def parse_embedding_response(
self, response: CreateEmbeddingResponse
Expand Down Expand Up @@ -268,11 +294,19 @@ def convert_inputs_to_api_kwargs(
system_end_tag = "<END_OF_SYSTEM_PROMPT>"
user_start_tag = "<START_OF_USER_PROMPT>"
user_end_tag = "<END_OF_USER_PROMPT>"
pattern = f"{system_start_tag}(.*?){system_end_tag}{user_start_tag}(.*?){user_end_tag}"

# new regex pattern to ignore special characters such as \n
pattern = (
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original regex pattern failed to match correctly. I fixed it by adding \s* so that surrounding whitespace (including newlines) around the tags is tolerated.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pretty good!

rf"{system_start_tag}\s*(.*?)\s*{system_end_tag}\s*"
rf"{user_start_tag}\s*(.*?)\s*{user_end_tag}"
)

# Compile the regular expression
regex = re.compile(pattern)

# re.DOTALL is to allow . to match newline so that (.*?) does not match in a single line
regex = re.compile(pattern, re.DOTALL)
# Match the pattern
match = regex.search(input)
match = regex.match(input)
system_prompt, input_str = None, None

if match:
Expand Down Expand Up @@ -328,6 +362,9 @@ def convert_inputs_to_api_kwargs(
final_model_kwargs["mask"] = self._encode_image(mask)
else:
raise ValueError(f"model_type {model_type} is not supported")

print(f"final_model_kwargs: {final_model_kwargs}")

return final_model_kwargs

def parse_image_generation_response(self, response: List[Image]) -> GeneratorOutput:
Expand Down Expand Up @@ -362,6 +399,7 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINE
kwargs is the combined input and model_kwargs. Support streaming call.
"""
log.info(f"api_kwargs: {api_kwargs}")
self._api_kwargs = api_kwargs
if model_type == ModelType.EMBEDDER:
return self.sync_client.embeddings.create(**api_kwargs)
elif model_type == ModelType.LLM:
Expand Down Expand Up @@ -403,6 +441,8 @@ async def acall(
"""
kwargs is the combined input and model_kwargs
"""
# store the api kwargs in the client
self._api_kwargs = api_kwargs
if self.async_client is None:
self.async_client = self.init_async_client()
if model_type == ModelType.EMBEDDER:
Expand Down Expand Up @@ -498,21 +538,21 @@ def _prepare_image_content(


# Example usage:
# if __name__ == "__main__":
# from adalflow.core import Generator
# from adalflow.utils import setup_env, get_logger
#
# log = get_logger(level="DEBUG")
#
# setup_env()
# prompt_kwargs = {"input_str": "What is the meaning of life?"}
#
# gen = Generator(
# model_client=OpenAIClient(),
# model_kwargs={"model": "gpt-3.5-turbo", "stream": True},
# )
# gen_response = gen(prompt_kwargs)
# print(f"gen_response: {gen_response}")
#
# for genout in gen_response.data:
# print(f"genout: {genout}")
if __name__ == "__main__":
    # Smoke test against the live OpenAI API; requires OPENAI_API_KEY
    # to be resolvable by setup_env().
    from adalflow.core import Generator
    from adalflow.utils import setup_env, get_logger

    log = get_logger(level="DEBUG")

    setup_env()
    prompt_kwargs = {"input_str": "What is the meaning of life?"}

    gen = Generator(
        model_client=OpenAIClient(),
        # stream=True exercises the streaming parse path.
        model_kwargs={"model": "gpt-3.5-turbo", "stream": True},
    )
    gen_response = gen(prompt_kwargs)
    print(f"gen_response: {gen_response}")

    # NOTE(review): assumes gen_response.data is iterable when streaming —
    # TODO confirm against Generator's streaming output contract.
    for genout in gen_response.data:
        print(f"genout: {genout}")
41 changes: 41 additions & 0 deletions adalflow/tests/test_deepseek_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import unittest
from unittest.mock import patch, Mock
import os

from openai import Stream
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from adalflow.core.types import ModelType, GeneratorOutput
from adalflow.components.model_client.deepseek_client import DeepSeekClient
from unittest.mock import AsyncMock

def getenv_side_effect(key):
    """Stand-in for ``os.getenv``: only DEEPSEEK_API_KEY resolves, to a fake key."""
    if key == "DEEPSEEK_API_KEY":
        return "fake_api_key"
    return None

class TestDeepSeekClient(unittest.TestCase):
    # Unit tests for DeepSeekClient configuration; no network calls are made
    # because only client wiring (base_url, env var name, key lookup) is asserted.

    def setUp(self):
        # Passing api_key directly avoids the DEEPSEEK_API_KEY env lookup in setUp.
        self.client = DeepSeekClient(api_key="fake_api_key")

    def test_deepseek_init(self):
        # Verify the DeepSeek-specific defaults forwarded to OpenAIClient.__init__.
        self.assertEqual(self.client.base_url, "https://api.deepseek.com/v1/")
        self.assertEqual(self.client._input_type, "messages")
        self.assertEqual(self.client._env_api_key_name, "DEEPSEEK_API_KEY")

    # mock os.getenv(self._env_api_key_name) with getenv_side_effect
    @patch("os.getenv")
    def test_deepseek_init_sync_client(self, mock_os_getenv):
        mock_os_getenv.return_value = "fake_api_key"
        # NOTE(review): init_sync_client() returns a new client but the result is
        # discarded; the assertions check the client built in setUp — confirm intent.
        self.client.init_sync_client()
        self.assertEqual(self.client.sync_client.api_key, "fake_api_key")
        self.assertEqual(self.client.sync_client.base_url, "https://api.deepseek.com/v1/")

    @patch("os.getenv")
    def test_deepseek_init_async_client(self, mock_os_getenv):
        mock_os_getenv.return_value = "fake_api_key"
        self.client.async_client = self.client.init_async_client()
        self.assertEqual(self.client.async_client.api_key, "fake_api_key")
        self.assertEqual(self.client.async_client.base_url, "https://api.deepseek.com/v1/")

if __name__ == "__main__":
unittest.main()
Loading
Loading