diff --git a/.env.sample b/.env.sample index cc933d9..2632f05 100644 --- a/.env.sample +++ b/.env.sample @@ -36,3 +36,6 @@ XAI_API_KEY= # Sambanova SAMBANOVA_API_KEY= + +# TONGYI +TONGYI_API_KEY= diff --git a/aisuite/providers/tongyi_provider.py b/aisuite/providers/tongyi_provider.py index c3e1841..a5949df 100644 --- a/aisuite/providers/tongyi_provider.py +++ b/aisuite/providers/tongyi_provider.py @@ -1,37 +1,33 @@ import os -import dashscope -from aisuite.provider import Provider -from aisuite.framework import ChatCompletionResponse +import openai +from aisuite.provider import Provider, LLMError -class TongyiProvider(Provider): - """TongyiProvider is a class that provides an interface to the Tongyi's model.""" +class TongyiProvider(Provider): def __init__(self, **config): - self.api_key = config.get("api_key") or os.getenv("DASHSCOPE_API_KEY") - - if not self.api_key: - raise EnvironmentError( - "Dashscope API key is missing. Please provide it in the config or set the DASHSCOPE_API_KEY environment variable." + """ + Initialize the Tongyi provider with the given configuration. + Pass the entire configuration dictionary to the Tongyi client constructor. + """ + # Ensure API key is provided either in config or via environment variable + config.setdefault("api_key", os.getenv("TONGYI_API_KEY")) + config["base_url"] = "https://dashscope.aliyuncs.com/compatible-mode/v1" + + if not config["api_key"]: + raise ValueError( + "Tongyi API key is missing. Please provide it in the config or set the TONGYI_API_KEY environment variable." 
) - def chat_completions_create(self, model, messages, **kwargs): - """Send a chat completion request to the Tongyi's model.""" - - response = dashscope.Generation.call( - api_key=self.api_key, - model=model, - messages=messages, - result_format="message", - **kwargs - ) - return self.normalize_response(response) - - def normalize_response(self, response): - """Normalize the response from Dashscope to match OpenAI's response format.""" + self.client = openai.OpenAI(**config) - openai_response = ChatCompletionResponse() - openai_response.choices[0].message.content = response["output"]["choices"][0][ - "message" - ].get("content") - return openai_response + def chat_completions_create(self, model, messages, **kwargs): + try: + response = self.client.chat.completions.create( + model=model, + messages=messages, + **kwargs, # Pass any additional arguments to the Tongyi API + ) + return response + except Exception as e: + raise LLMError(f"An error occurred: {e}") diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 04d1243..6620dcd 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -36,6 +36,9 @@ def provider_configs(): "nebius": { "api_key": "nebius-api-key", }, + "tongyi": { + "api_key": "tongyi-api-key", + }, } @@ -87,6 +90,11 @@ def provider_configs(): "nebius", "nebius-model", ), + ( + "aisuite.providers.tongyi_provider.TongyiProvider.chat_completions_create", + "tongyi", + "tongyi-model", + ), ], ) def test_client_chat_completions( diff --git a/tests/client/test_prerelease.py b/tests/client/test_prerelease.py index bb5f328..46b2261 100644 --- a/tests/client/test_prerelease.py +++ b/tests/client/test_prerelease.py @@ -23,6 +23,7 @@ def get_test_models() -> List[str]: "mistral:open-mistral-7b", "openai:gpt-3.5-turbo", "cohere:command-r-plus-08-2024", + "tongyi:qwen-plus", ] diff --git a/tests/providers/test_tongyi_provider.py b/tests/providers/test_tongyi_provider.py index dc1ecca..d88fcd8 100644 --- 
a/tests/providers/test_tongyi_provider.py +++ b/tests/providers/test_tongyi_provider.py @@ -1,32 +1,34 @@ from unittest.mock import MagicMock, patch import pytest -import dashscope + from aisuite.providers.tongyi_provider import TongyiProvider @pytest.fixture(autouse=True) def set_api_key_env_var(monkeypatch): """Fixture to set environment variables for tests.""" - monkeypatch.setenv("DASHSCOPE_API_KEY", "test-api-key") + monkeypatch.setenv("TONGYI_API_KEY", "test-api-key") -def test_tongyi_provider(): +def test_tongyi_provider(): """High-level test that the provider is initialized and chat completions are requested successfully.""" user_greeting = "Hello!" message_history = [{"role": "user", "content": user_greeting}] selected_model = "qwen-plus" - chosen_temperature = 0 + chosen_temperature = 0.8 response_text_content = "mocked-text-response-from-model" provider = TongyiProvider() mock_response = MagicMock() - mock_response = { - "output": {"choices": [{"message": {"content": response_text_content}}]} - } + mock_response.choices = [MagicMock()] + mock_response.choices[0].message = MagicMock() + mock_response.choices[0].message.content = response_text_content with patch.object( - dashscope.Generation, "call", return_value=mock_response + provider.client.chat.completions, + "create", + return_value=mock_response, ) as mock_create: response = provider.chat_completions_create( messages=message_history, @@ -35,11 +37,9 @@ def test_tongyi_provider(): ) mock_create.assert_called_with( - api_key=provider.api_key, messages=message_history, model=selected_model, temperature=chosen_temperature, - result_format="message", ) assert response.choices[0].message.content == response_text_content