Skip to content

Commit

Permalink
Merge pull request #560 from valory-xyz/feature/protocol_generator_te…
Browse files Browse the repository at this point in the history
…sts_generation

[WIP] Feature/protocol generator tests generation
  • Loading branch information
DavidMinarsch authored Feb 1, 2023
2 parents 2d0b503 + 4f2dba7 commit 0816fd9
Show file tree
Hide file tree
Showing 92 changed files with 2,802 additions and 201 deletions.
8 changes: 5 additions & 3 deletions aea/cli/generate_all_protocols.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# ------------------------------------------------------------------------------
#
# Copyright 2021-2022 Valory AG
# Copyright 2021-2023 Valory AG
# Copyright 2018-2019 Fetch.AI Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -38,6 +38,7 @@
import subprocess # nosec
import sys
import tempfile
from distutils.dir_util import copy_tree # pylint: disable=deprecated-module
from pathlib import Path
from typing import Any, List, Tuple, cast

Expand Down Expand Up @@ -174,8 +175,9 @@ def _fix_generated_protocol(package_path: Path) -> None:
tests_module = package_path / AEA_TEST_DIRNAME
if tests_module.is_dir():
log(f"Restore original `tests` directory in {package_path}")
shutil.copytree(
tests_module, Path(PROTOCOLS, package_path.name, AEA_TEST_DIRNAME)
copy_tree(
str(tests_module),
str(Path(PROTOCOLS, package_path.name, AEA_TEST_DIRNAME)),
)


Expand Down
2 changes: 1 addition & 1 deletion aea/configurations/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
_COSMOS_IDENTIFIER = "cosmos"
SIGNING_PROTOCOL = "open_aea/signing:latest"
SIGNING_PROTOCOL_WITH_HASH = (
"open_aea/signing:1.0.0:bafybeiclsbgrviyxbmi2vex5ze3dhr7ywohrqedebx26jozayxvroqtegq"
"open_aea/signing:1.0.0:bafybeibqlfmikg5hk4phzak6gqzhpkt6akckx7xppbp53mvwt6r73h7tk4"
)
DEFAULT_LEDGER = _ETHEREUM_IDENTIFIER
PRIVATE_KEY_PATH_SCHEMA = "{}_private_key.txt"
Expand Down
326 changes: 326 additions & 0 deletions aea/protocols/generator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2129,6 +2129,8 @@ def generate_full_mode(self, language: str) -> Optional[str]:
self._serialization_class_str(),
)

self._generate_tests()

# Run black formatting
try_run_black_formatting(self.path_to_generated_protocol_package)

Expand All @@ -2146,6 +2148,330 @@ def generate_full_mode(self, language: str) -> Optional[str]:
full_mode_output = incomplete_generation_warning_msg
return full_mode_output

def _generate_tests(self) -> None:
tests_dir = str(Path(self.path_to_generated_protocol_package) / "tests")
os.makedirs(tests_dir, exist_ok=True)

TESTS_MESSAGES_DOT_PY_FILE_NAME = (
f"test_{self.protocol_specification.name}_messages.py"
)
_create_protocol_file(
tests_dir,
TESTS_MESSAGES_DOT_PY_FILE_NAME,
self._test_messages_file_str(),
)

TESTS_DIALOGUES_DOT_PY_FILE_NAME = (
f"test_{self.protocol_specification.name}_dialogues.py"
)
_create_protocol_file(
tests_dir,
TESTS_DIALOGUES_DOT_PY_FILE_NAME,
self._test_dialogues_file_str(),
)

def _test_messages_file_str(self) -> str:
"""
Produce the content of the test_messages.py.
:return: file content
"""
self._change_indent(0, "s")

# Header
cls_str = _copyright_header_str(self.protocol_specification.author) + "\n"

# Module docstring
cls_str += (
self.indent
+ '"""Test messages module for {} protocol."""\n\n'.format(
self.protocol_specification.name
)
)

cls_str += f"# pylint: disable={','.join(PYLINT_DISABLE_SERIALIZATION_PY)}\n"

# Imports
cls_str += self.indent + "from typing import List\n\n"
cls_str += (
self.indent
+ "from aea.test_tools.test_protocol import BaseProtocolMessagesTestCase\n"
)

for custom_type in self.spec.all_custom_types:
cls_str += (
self.indent
+ "from {}.custom_types import (\n {},\n)\n".format(
self.dotted_path_to_protocol_package,
custom_type,
)
)
cls_str += self.indent + "from {}.message import (\n {}Message,\n)\n".format(
self.dotted_path_to_protocol_package,
self.protocol_specification_in_camel_case,
)

# Class Header
cls_str += (
self.indent
+ "\n\nclass TestMessage{}(BaseProtocolMessagesTestCase):\n".format(
self.protocol_specification_in_camel_case,
)
)
self._change_indent(1)
cls_str += (
self.indent
+ '"""Test for the \'{}\' protocol message."""\n\n'.format(
self.protocol_specification.name,
)
)

msg_class = f"{self.protocol_specification_in_camel_case}Message"
cls_str += self.indent + "\n"
cls_str += self.indent + f"MESSAGE_CLASS = {msg_class}\n\n"
cls_str += (
self.indent
+ f"def build_messages(self) -> List[{msg_class}]: # type: ignore[override]\n"
)
self._change_indent(1)
cls_str += self.indent + '"""Build the messages to be used for testing."""\n'

cls_str += self.indent + "return [\n"

for performative, content in self.spec.speech_acts.items():
cls_str += (
self.indent
+ f"""
{msg_class}(
performative={msg_class}.Performative.{performative.upper()}\n,
"""
)
for content_name, content_type in content.items():
if content_type in self.spec.all_custom_types:
cls_str += (
self.indent
+ f"{content_name} = {content_type}(), # check it please!\n"
)
else:
cls_str += (
self.indent
+ f"{content_name} = {self._make_type_value(content_type)},\n"
)

cls_str += self.indent + "),\n"

cls_str += self.indent + "]\n"

self._change_indent(-1)

cls_str += (
self.indent
+ f"def build_inconsistent(self) -> List[{msg_class}]: # type: ignore[override]\n"
)

self._change_indent(1)
cls_str += (
self.indent + '"""Build inconsistent messages to be used for testing."""\n'
)

cls_str += self.indent + "return [\n"

for performative, content in self.spec.speech_acts.items():
if len(content) == 0:
# no content to skip
continue

cls_str += (
self.indent
+ f"""
{msg_class}(
performative={msg_class}.Performative.{performative.upper()}\n,
"""
)
idx = 0
for content_name, content_type in content.items():
idx += 1
if idx == 1:
cls_str += self.indent + f"# skip content: {content_name}\n"
continue
if content_type in self.spec.all_custom_types:
cls_str += (
self.indent
+ f"{content_name} = {content_type}(), # check it please!\n"
)
else:
cls_str += (
self.indent
+ f"{content_name} = {self._make_type_value(content_type)},\n"
)

cls_str += self.indent + "),\n"

cls_str += self.indent + "]\n"

return cls_str

def _test_dialogues_file_str(self) -> str:
"""
Produce the content of the test_dialogues.py.
:return: file content
"""
self._change_indent(0, "s")

# Header
cls_str = _copyright_header_str(self.protocol_specification.author) + "\n"

# Module docstring
cls_str += (
self.indent
+ '"""Test dialogues module for {} protocol."""\n\n'.format(
self.protocol_specification.name
)
)

cls_str += f"# pylint: disable={','.join(PYLINT_DISABLE_SERIALIZATION_PY)}\n"

# Imports
cls_str += (
self.indent
+ "from aea.test_tools.test_protocol import BaseProtocolDialoguesTestCase\n"
)

msg_class = f"{self.protocol_specification_in_camel_case}Message"
performative = self.spec.initial_performatives[0]
content = self.spec.speech_acts[performative.lower()]
role = self.spec.roles[0]

for custom_type in self.spec.all_custom_types:
if custom_type not in content.values():
# skip unused custom types
continue
cls_str += (
self.indent
+ "from {}.custom_types import (\n {},\n)\n".format(
self.dotted_path_to_protocol_package,
custom_type,
)
)
cls_str += self.indent + "from {}.message import (\n {}Message,\n)\n".format(
self.dotted_path_to_protocol_package,
self.protocol_specification_in_camel_case,
)

cls_str += (
self.indent
+ "from {}.dialogues import (\n {}Dialogue,\n)\n".format(
self.dotted_path_to_protocol_package,
self.protocol_specification_in_camel_case,
)
)
cls_str += (
self.indent
+ "from {}.dialogues import (\n {}Dialogues,\n)\n".format(
self.dotted_path_to_protocol_package,
self.protocol_specification_in_camel_case,
)
)

# Class Header
cls_str += (
self.indent
+ "\n\nclass TestDialogues{}(BaseProtocolDialoguesTestCase):\n".format(
self.protocol_specification_in_camel_case,
)
)
self._change_indent(1)
cls_str += (
self.indent
+ '"""Test for the \'{}\' protocol dialogues."""\n\n'.format(
self.protocol_specification.name,
)
)

cls_str += self.indent + "\n"
cls_str += self.indent + f"MESSAGE_CLASS = {msg_class}\n\n"
cls_str += (
self.indent
+ f"DIALOGUE_CLASS = {self.protocol_specification_in_camel_case}Dialogue\n\n"
)
cls_str += (
self.indent
+ f"DIALOGUES_CLASS = {self.protocol_specification_in_camel_case}Dialogues\n\n"
)
cls_str += (
self.indent
+ f"ROLE_FOR_THE_FIRST_MESSAGE = {self.protocol_specification_in_camel_case}Dialogue.Role.{role.upper()} # CHECK\n\n"
)

cls_str += self.indent + "def make_message_content(self) -> dict:\n"
self._change_indent(1)
cls_str += (
self.indent
+ '"""Make a dict with message contruction content for dialogues.create."""\n'
)

cls_str += self.indent + "return dict(\n"
cls_str += (
self.indent
+ f"performative={msg_class}.Performative.{performative.upper()},"
)

for content_name, content_type in content.items():
if content_type in self.spec.all_custom_types:
cls_str += (
self.indent
+ f"{content_name} = {content_type}(), # check it please!\n"
)
else:
cls_str += (
self.indent
+ f"{content_name} = {self._make_type_value(content_type)},\n"
)
cls_str += self.indent + ")\n"
return cls_str

def _make_type_value(self, content_type: str) -> str:
"""
Make a value of type definition.
:param content_type: str type definition
:returns: str value
"""
type_map = {
"bytes": 'b"some_bytes"',
"str": '"some str"',
"bool": "True",
"int": "12",
"float": "1.0",
}
if content_type in type_map:
return type_map[content_type]

if content_type.startswith("List"):
inner_type = content_type[5:-1]
return f"[{self._make_type_value(inner_type)}]"
elif content_type.startswith("FrozenSet"):
inner_type = content_type[10:-1]
return f"frozenset([{self._make_type_value(inner_type)}])"
elif content_type.startswith("Union"):
inner_type = content_type[6:-1].split(",")[0].strip()
return f"{self._make_type_value(inner_type)}"
elif content_type.startswith("Tuple"):
inner_type = content_type[6:-1].split(",")[0].strip()
return f"({self._make_type_value(inner_type)},)"
elif content_type.startswith("Dict"):
inner_type1, inner_type2 = [
i.strip() for i in content_type[5:-1].split(",")
]
return f"{{ {self._make_type_value(inner_type1)} : {self._make_type_value(inner_type2)}}}"
elif content_type.startswith("Optional"):
inner_type = content_type[9:-1]
return f"{self._make_type_value(inner_type)}"

return f"{content_type}()"

def generate(
self, protobuf_only: bool = False, language: str = PROTOCOL_LANGUAGE_PYTHON
) -> Optional[str]:
Expand Down
Loading

0 comments on commit 0816fd9

Please sign in to comment.