Commit b4a0b5f

Add unit and integration tests

vblagoje committed Apr 4, 2024
1 parent 3895c0f commit b4a0b5f

Showing 3 changed files with 187 additions and 2 deletions.
@@ -284,6 +284,7 @@ class MistralChatAdapter(BedrockModelChatAdapter):
"{% set loop_messages = messages %}"
"{% set system_message = false %}"
"{% endif %}"
"{{bos_token}}"
"{% for message in loop_messages %}"
"{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
"{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
@@ -304,7 +305,6 @@ class MistralChatAdapter(BedrockModelChatAdapter):

    # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
    ALLOWED_PARAMS: ClassVar[List[str]] = [
-       "anthropic_version",
        "max_tokens",
        "safe_prompt",
        "random_seed",
@@ -346,7 +346,10 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[
        default_params = {
            "max_tokens": self.generation_kwargs.get("max_tokens") or 512,  # max_tokens is required
        }

+       # replace stop_words from inference_kwargs with stop, as this is Mistral-specific
+       stop_words = inference_kwargs.pop("stop_words", [])
+       if stop_words:
+           inference_kwargs["stop"] = stop_words
        params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS)
        body = {"prompt": self.prepare_chat_messages(messages=messages), **params}
        return body
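
A minimal usage sketch of the new stop_words handling (hypothetical values; assumes "stop" appears in the elided portion of ALLOWED_PARAMS, otherwise _get_params would filter it out):

adapter = MistralChatAdapter(generation_kwargs={})
body = adapter.prepare_body(
    [ChatMessage.from_user("Hello, how are you?")],
    stop_words=["</s>"],  # aliased to Mistral's "stop" parameter by prepare_body
)
# body == {"prompt": "<s>[INST] Hello, how are you? [/INST]", "max_tokens": 512, "stop": ["</s>"]}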
48 changes: 48 additions & 0 deletions integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -9,10 +9,12 @@
    AnthropicClaudeChatAdapter,
    BedrockModelChatAdapter,
    MetaLlama2ChatAdapter,
    MistralChatAdapter,
)

KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator"
MODELS_TO_TEST = ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"]
MISTRAL_MODELS = ["mistral.mistral-7b-instruct-v0:2", "mistral.mixtral-8x7b-instruct-v0:1"]


def test_to_dict(mock_boto3_session):
@@ -176,6 +178,52 @@ def test_prepare_body_with_custom_inference_params(self) -> None:
        assert body == expected_body


class TestMistralAdapter:
    def test_prepare_body_with_default_params(self) -> None:
        layer = MistralChatAdapter(generation_kwargs={})
        prompt = "Hello, how are you?"
        expected_body = {
            "max_tokens": 512,
            "prompt": "<s>[INST] Hello, how are you? [/INST]",
        }

        body = layer.prepare_body([ChatMessage.from_user(prompt)])

        assert body == expected_body

    def test_prepare_body_with_custom_inference_params(self) -> None:
        layer = MistralChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4})
        prompt = "Hello, how are you?"
        expected_body = {
            "prompt": "<s>[INST] Hello, how are you? [/INST]",
            "max_tokens": 512,
            "temperature": 0.7,
            "top_p": 0.8,
        }

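        # top_k and max_tokens_to_sample are intentionally passed below but absent from
        # expected_body: the ALLOWED_PARAMS filter is expected to drop them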
        body = layer.prepare_body([ChatMessage.from_user(prompt)], top_p=0.8, top_k=5, max_tokens_to_sample=69)

        assert body == expected_body

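    # The integration test below calls the live Amazon Bedrock API; it assumes
    # valid AWS credentials and access to the Mistral models in the target region.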
    @pytest.mark.parametrize("model_name", MISTRAL_MODELS)
    @pytest.mark.integration
    def test_default_inference_params(self, model_name, chat_messages):
        client = AmazonBedrockChatGenerator(model=model_name)
        response = client.run(chat_messages)

        assert "replies" in response, "Response does not contain 'replies' key"
        replies = response["replies"]
        assert isinstance(replies, list), "Replies is not a list"
        assert len(replies) > 0, "No replies received"

        first_reply = replies[0]
        assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
        assert first_reply.content, "First reply has no content"
        assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
        assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'"
        assert first_reply.meta, "First reply has no metadata"


@pytest.fixture
def chat_messages():
    messages = [
134 changes: 134 additions & 0 deletions integrations/amazon_bedrock/tests/test_generator.py
@@ -11,6 +11,7 @@
    BedrockModelAdapter,
    CohereCommandAdapter,
    MetaLlama2ChatAdapter,
    MistralAdapter,
)


@@ -367,6 +368,139 @@ def test_get_stream_responses_empty(self) -> None:
        stream_handler_mock.assert_not_called()


class TestMistralAdapter:
    def test_prepare_body_with_default_params(self) -> None:
        layer = MistralAdapter(model_kwargs={}, max_length=99)
        prompt = "Hello, how are you?"
        expected_body = {"prompt": "<s>[INST] Hello, how are you? [/INST]", "max_tokens": 99, "stop": []}

        body = layer.prepare_body(prompt)
        assert body == expected_body

    def test_prepare_body_with_custom_inference_params(self) -> None:
        layer = MistralAdapter(model_kwargs={}, max_length=99)
        prompt = "Hello, how are you?"
        expected_body = {
            "prompt": "<s>[INST] Hello, how are you? [/INST]",
            "max_tokens": 50,
            "stop": ["CUSTOM_STOP"],
            "temperature": 0.7,
            "top_p": 0.8,
            "top_k": 5,
        }

        body = layer.prepare_body(
            prompt,
            temperature=0.7,
            top_p=0.8,
            top_k=5,
            max_tokens=50,
            stop=["CUSTOM_STOP"],
            unknown_arg="unknown_value",
        )

        assert body == expected_body

    def test_prepare_body_with_model_kwargs(self) -> None:
        layer = MistralAdapter(
            model_kwargs={
                "temperature": 0.7,
                "top_p": 0.8,
                "top_k": 5,
                "max_tokens": 50,
                "stop": ["CUSTOM_STOP"],
                "unknown_arg": "unknown_value",
            },
            max_length=99,
        )
        prompt = "Hello, how are you?"
        expected_body = {
            "prompt": "<s>[INST] Hello, how are you? [/INST]",
            "max_tokens": 50,
            "stop": ["CUSTOM_STOP"],
            "temperature": 0.7,
            "top_p": 0.8,
            "top_k": 5,
        }

        body = layer.prepare_body(prompt)

        assert body == expected_body

    def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> None:
        layer = MistralAdapter(
            model_kwargs={
                "temperature": 0.6,
                "top_p": 0.7,
                "top_k": 4,
                "max_tokens": 49,
                "stop": ["CUSTOM_STOP_MODEL_KWARGS"],
            },
            max_length=99,
        )
        prompt = "Hello, how are you?"
        expected_body = {
            "prompt": "<s>[INST] Hello, how are you? [/INST]",
            "max_tokens": 50,
            "stop": ["CUSTOM_STOP_MODEL_KWARGS"],
            "temperature": 0.7,
            "top_p": 0.8,
            "top_k": 5,
        }

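        # call-time kwargs are expected to override model_kwargs; "stop" is not passed
        # at call time, so the value from model_kwargs should remain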
        body = layer.prepare_body(prompt, temperature=0.7, top_p=0.8, top_k=5, max_tokens=50)

        assert body == expected_body

    def test_get_responses(self) -> None:
        adapter = MistralAdapter(model_kwargs={}, max_length=99)
        response_body = {"outputs": [{"text": "This is a single response."}]}
        expected_responses = ["This is a single response."]
        assert adapter.get_responses(response_body) == expected_responses

    def test_get_stream_responses(self) -> None:
        stream_mock = MagicMock()
        stream_handler_mock = MagicMock()

        stream_mock.__iter__.return_value = [
            {"chunk": {"bytes": b'{"outputs": [{"text": " This"}]}'}},
            {"chunk": {"bytes": b'{"outputs": [{"text": " is"}]}'}},
            {"chunk": {"bytes": b'{"outputs": [{"text": " a"}]}'}},
            {"chunk": {"bytes": b'{"outputs": [{"text": " single"}]}'}},
            {"chunk": {"bytes": b'{"outputs": [{"text": " response."}]}'}},
        ]

        stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received

        adapter = MistralAdapter(model_kwargs={}, max_length=99)
        expected_responses = ["This is a single response."]
        assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses

        stream_handler_mock.assert_has_calls(
            [
                call(" This", event_data={"outputs": [{"text": " This"}]}),
                call(" is", event_data={"outputs": [{"text": " is"}]}),
                call(" a", event_data={"outputs": [{"text": " a"}]}),
                call(" single", event_data={"outputs": [{"text": " single"}]}),
                call(" response.", event_data={"outputs": [{"text": " response."}]}),
            ]
        )

    def test_get_stream_responses_empty(self) -> None:
        stream_mock = MagicMock()
        stream_handler_mock = MagicMock()

        stream_mock.__iter__.return_value = []

        stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received

        adapter = MistralAdapter(model_kwargs={}, max_length=99)
        expected_responses = [""]
        assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses

        stream_handler_mock.assert_not_called()
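
The streaming tests above pin down the Bedrock stream format for Mistral: each event carries JSON bytes under event["chunk"]["bytes"], with the token at outputs[0].text. A sketch of one plausible decode loop consistent with those assertions (an assumption, not the adapter's actual code):

import json

def decode_mistral_stream(stream, stream_handler):
    # Decode each JSON chunk, forward the token and the decoded event to the
    # handler, and return the concatenated tokens as a single stripped response.
    tokens = []
    for event in stream:
        chunk = event.get("chunk")
        if chunk:
            decoded = json.loads(chunk["bytes"].decode("utf-8"))
            token = decoded["outputs"][0]["text"]
            stream_handler(token, event_data=decoded)
            tokens.append(token)
    return ["".join(tokens).strip()]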


class TestCohereCommandAdapter:
    def test_prepare_body_with_default_params(self) -> None:
        layer = CohereCommandAdapter(model_kwargs={}, max_length=99)
