Skip to content

Commit

Permalink
update tongyi provider
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangshulin committed Jan 28, 2025
1 parent cdf9ff4 commit 4f8ed4b
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 39 deletions.
3 changes: 3 additions & 0 deletions .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,6 @@ XAI_API_KEY=

# Sambanova
SAMBANOVA_API_KEY=

# TONGYI
TONGYI_API_KEY=
54 changes: 25 additions & 29 deletions aisuite/providers/tongyi_provider.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,33 @@
import os
import dashscope
from aisuite.provider import Provider
from aisuite.framework import ChatCompletionResponse
import openai

from aisuite.provider import Provider, LLMError

class TongyiProvider(Provider):
    """Provider adapter exposing Alibaba's Tongyi (DashScope) models.

    Talks to the OpenAI-compatible endpoint that DashScope exposes, so the
    stock ``openai`` client can be reused unchanged.
    """

    # DashScope's OpenAI-compatible endpoint (default; overridable via config).
    _BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"

    def __init__(self, **config):
        """
        Initialize the Tongyi provider with the given configuration.

        The entire configuration dictionary is passed through to the
        underlying ``openai.OpenAI`` client constructor.

        Raises:
            ValueError: If no API key is supplied via ``config`` or the
                ``TONGYI_API_KEY`` environment variable.
        """
        # Ensure an API key is provided either in config or via environment variable.
        config.setdefault("api_key", os.getenv("TONGYI_API_KEY"))
        # Default the endpoint, but let an explicit base_url in config win so
        # callers can point at a proxy or regional endpoint.
        config.setdefault("base_url", self._BASE_URL)

        if not config["api_key"]:
            raise ValueError(
                "Tongyi API key is missing. Please provide it in the config or set the TONGYI_API_KEY environment variable."
            )

        self.client = openai.OpenAI(**config)

    def chat_completions_create(self, model, messages, **kwargs):
        """Create a chat completion via the Tongyi OpenAI-compatible API.

        Args:
            model: Tongyi model name, e.g. ``"qwen-plus"``.
            messages: OpenAI-style list of message dicts.
            **kwargs: Extra arguments forwarded verbatim to the API.

        Returns:
            The raw OpenAI-format chat completion response.

        Raises:
            LLMError: Wrapping any failure raised by the underlying client.
        """
        try:
            return self.client.chat.completions.create(
                model=model,
                messages=messages,
                **kwargs,  # Pass any additional arguments to the Tongyi API
            )
        except Exception as e:
            # Chain the original exception so the real traceback is preserved.
            raise LLMError(f"An error occurred: {e}") from e
8 changes: 8 additions & 0 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def provider_configs():
"nebius": {
"api_key": "nebius-api-key",
},
"tongyi": {
"api_key": "tongyi-api-key",
},
}


Expand Down Expand Up @@ -87,6 +90,11 @@ def provider_configs():
"nebius",
"nebius-model",
),
(
"aisuite.providers.tongyi_provider.TongyiProvider.chat_completions_create",
"tongyi",
"tongyi-model",
),
],
)
def test_client_chat_completions(
Expand Down
1 change: 1 addition & 0 deletions tests/client/test_prerelease.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def get_test_models() -> List[str]:
"mistral:open-mistral-7b",
"openai:gpt-3.5-turbo",
"cohere:command-r-plus-08-2024",
"tongyi:qwen-plus",
]


Expand Down
20 changes: 10 additions & 10 deletions tests/providers/test_tongyi_provider.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,34 @@
from unittest.mock import MagicMock, patch
import pytest
import dashscope

from aisuite.providers.tongyi_provider import TongyiProvider


@pytest.fixture(autouse=True)
def set_api_key_env_var(monkeypatch):
    """Ensure every test in this module sees a Tongyi API key."""
    monkeypatch.setenv("TONGYI_API_KEY", "test-api-key")


def test_tongyi_provider():
    """High-level test that the provider is initialized and chat completions are requested successfully."""

    user_greeting = "Hello!"
    message_history = [{"role": "user", "content": user_greeting}]
    selected_model = "qwen-plus"
    chosen_temperature = 0.8
    response_text_content = "mocked-text-response-from-model"

    provider = TongyiProvider()

    # Build an OpenAI-style response object: response.choices[0].message.content
    mock_response = MagicMock()
    mock_response.choices = [MagicMock()]
    mock_response.choices[0].message = MagicMock()
    mock_response.choices[0].message.content = response_text_content

    with patch.object(
        provider.client.chat.completions,
        "create",
        return_value=mock_response,
    ) as mock_create:
        response = provider.chat_completions_create(
            messages=message_history,
            model=selected_model,
            temperature=chosen_temperature,
        )

    # The provider must forward model, messages and extra kwargs verbatim.
    mock_create.assert_called_with(
        model=selected_model,
        messages=message_history,
        temperature=chosen_temperature,
    )

    assert response.choices[0].message.content == response_text_content

0 comments on commit 4f8ed4b

Please sign in to comment.