diff --git a/aea/configurations/base.py b/aea/configurations/base.py index 080ba7bcbc..50dd99fed8 100644 --- a/aea/configurations/base.py +++ b/aea/configurations/base.py @@ -81,6 +81,7 @@ SimpleId, SimpleIdOrStr, load_module, + perform_dict_override, recursive_update, ) from aea.helpers.ipfs.base import IPFSHashOnly @@ -1580,7 +1581,12 @@ def all_components_id(self) -> List[ComponentId]: return result - def update(self, data: Dict, env_vars_friendly: bool = False) -> None: + def update( # pylint: disable=arguments-differ + self, + data: Dict, + env_vars_friendly: bool = False, + dict_overrides: Optional[Dict] = None, + ) -> None: """ Update configuration with other data. @@ -1589,6 +1595,7 @@ def update(self, data: Dict, env_vars_friendly: bool = False) -> None: :param data: the data to replace. :param env_vars_friendly: whether or not it is env vars friendly. + :param dict_overrides: A dictionary containing mapping for Component ID -> List of paths """ data = copy(data) # update component parts @@ -1599,6 +1606,7 @@ def update(self, data: Dict, env_vars_friendly: bool = False) -> None: for component_id, obj in new_component_configurations.items(): if component_id not in updated_component_configurations: updated_component_configurations[component_id] = obj + else: recursive_update( updated_component_configurations[component_id], @@ -1606,6 +1614,14 @@ def update(self, data: Dict, env_vars_friendly: bool = False) -> None: allow_new_values=True, ) + if dict_overrides is not None and component_id in dict_overrides: + perform_dict_override( + component_id, + dict_overrides, + updated_component_configurations, + new_component_configurations, + ) + self.check_overrides_valid(data, env_vars_friendly=env_vars_friendly) super().update(data, env_vars_friendly=env_vars_friendly) self.validate_config_data(self.json, env_vars_friendly=env_vars_friendly) diff --git a/aea/configurations/manager.py b/aea/configurations/manager.py index 349ad2a5e1..9e617e3bfc 100644 --- a/aea/configurations/manager.py +++ b/aea/configurations/manager.py @@ -20,6 +20,7 @@ """Implementation of the AgentConfigManager.""" import json import os +from collections import OrderedDict from copy import deepcopy from pathlib import Path from typing import Callable, Dict, List, NewType, Optional, Set, Tuple, Union, cast @@ -244,7 +245,9 @@ def handle_dotted_path( ) # find path to the resource directory - path_to_resource_directory = Path(".") / resource_type_plural / resource_name + path_to_resource_directory = ( + aea_project_path / resource_type_plural / resource_name + ) path_to_resource_configuration = ( path_to_resource_directory / RESOURCE_TYPE_TO_CONFIG_FILE[resource_type_plural] @@ -381,7 +384,15 @@ def set_variable(self, path: VariablePath, value: JSON_TYPES) -> None: # agent overrides.update(data) - self.update_config(overrides) + dict_overrides: Optional[Dict] = None + if isinstance(value, (dict, OrderedDict)): + dict_overrides = { + component_id: [ + json_path, + ] + } + + self.update_config(overrides, dict_overrides=dict_overrides) @staticmethod def _make_dict_for_path_and_value(json_path: JsonPath, value: JSON_TYPES) -> Dict: @@ -484,7 +495,11 @@ def _parse_path(self, path: VariablePath) -> Tuple[Optional[ComponentId], JsonPa ) return component_id, json_path - def update_config(self, overrides: Dict) -> None: + def update_config( + self, + overrides: Dict, + dict_overrides: Optional[Dict] = None, + ) -> None: """ Apply overrides for agent config. @@ -492,6 +507,7 @@ def update_config(self, overrides: Dict) -> None: Does not save it on the disc! :param overrides: overridden values dictionary + :param dict_overrides: A dictionary containing mapping for Component ID -> List of paths :return: None """ @@ -510,7 +526,11 @@ def update_config(self, overrides: Dict) -> None: obj, env_vars_friendly=self.env_vars_friendly ) - self.agent_config.update(overrides, env_vars_friendly=self.env_vars_friendly) + self.agent_config.update( + overrides, + env_vars_friendly=self.env_vars_friendly, + dict_overrides=dict_overrides, + ) def _filter_overrides(self, overrides: Dict) -> Dict: """Stay only updated values for agent config.""" diff --git a/aea/configurations/validation.py b/aea/configurations/validation.py index eae294d06a..1c77034aaf 100644 --- a/aea/configurations/validation.py +++ b/aea/configurations/validation.py @@ -21,6 +21,7 @@ import inspect import json import os +from collections import OrderedDict from copy import deepcopy from pathlib import Path from typing import Any, Dict, Iterator, List, Optional, Tuple @@ -309,12 +310,27 @@ def check_excludes(path: Tuple[str, ...]) -> bool: return True return False + def is_a_dict_override(path: Tuple[str, ...]) -> bool: + """Check if an override is a dict override.""" + flag = False + while len(path) > 0: + path = path[:-1] + if path in pattern_path_value: + pattern_value = pattern_path_value[path] + flag = isinstance(pattern_value, OrderedDict) + break + return flag + for path, new_value in data_path_value.items(): if check_excludes(path): continue if path not in pattern_path_value: - errors.append(f"Attribute `{'.'.join(path)}` is not allowed to be updated!") + if not is_a_dict_override(path=(*path,)): + errors.append( + f"Attribute `{'.'.join(path)}` is not allowed to be updated!" + ) + continue pattern_value = pattern_path_value[path] @@ -330,6 +346,9 @@ def check_excludes(path: Tuple[str, ...]) -> bool: # one of the values is env variable: skip data type check continue + if isinstance(pattern_value, OrderedDict) and isinstance(new_value, dict): + continue + if ( not issubclass(type(new_value), type(pattern_value)) and new_value is not None diff --git a/aea/helpers/base.py b/aea/helpers/base.py index 4e90630d7a..27823bc5f4 100644 --- a/aea/helpers/base.py +++ b/aea/helpers/base.py @@ -485,6 +485,33 @@ def recursive_update( to_update[key] = value +def perform_dict_override( + component_id: Any, + overrides: Dict, + updated_configuration: Dict, + new_configuration: Dict, +) -> None: + """ + Perform recursive dict override. + + :param component_id: Component ID for which the updated will be performed + :param overrides: A dictionary containing mapping for Component ID -> List of paths + :param updated_configuration: Configuration which needs to be updated + :param new_configuration: Configuration from which the method will perform the update + """ + for path in overrides[component_id]: + + will_be_updated = updated_configuration[component_id] + update = new_configuration[component_id] + + *params, update_param = path + for param in params: + will_be_updated = will_be_updated[param] + update = update[param] + + will_be_updated[update_param] = update[update_param] + + def _get_aea_logger_name_prefix(module_name: str, agent_name: str) -> str: """ Get the logger name prefix. diff --git a/docs/api/configurations/manager.md b/docs/api/configurations/manager.md index 08f20e2f07..5572dfa4e0 100644 --- a/docs/api/configurations/manager.md +++ b/docs/api/configurations/manager.md @@ -173,7 +173,8 @@ json friendly value. #### update`_`config ```python -def update_config(overrides: Dict) -> None +def update_config(overrides: Dict, + dict_overrides: Optional[Dict] = None) -> None ``` Apply overrides for agent config. @@ -184,6 +185,7 @@ Does not save it on the disc! **Arguments**: - `overrides`: overridden values dictionary +- `dict_overrides`: A dictionary containing mapping for Component ID -> List of paths **Returns**: diff --git a/docs/api/helpers/base.md b/docs/api/helpers/base.md index 578de3ad9c..cd9c766cd9 100644 --- a/docs/api/helpers/base.md +++ b/docs/api/helpers/base.md @@ -342,6 +342,25 @@ It does side-effects to the first dictionary. - `new_values`: the dictionary of new values to replace. - `allow_new_values`: whether or not to allow new values. + + +#### perform`_`dict`_`override + +```python +def perform_dict_override(component_id: Any, overrides: Dict, + updated_configuration: Dict, + new_configuration: Dict) -> None +``` + +Perform recursive dict override. + +**Arguments**: + +- `component_id`: Component ID for which the updated will be performed +- `overrides`: A dictionary containing mapping for Component ID -> List of paths +- `updated_configuration`: Configuration which needs to be updated +- `new_configuration`: Configuration from which the method will perform the update + #### find`_`topological`_`order diff --git a/tests/data/dummy_aea/aea-config.yaml b/tests/data/dummy_aea/aea-config.yaml index 6270178fdc..560840d36d 100644 --- a/tests/data/dummy_aea/aea-config.yaml +++ b/tests/data/dummy_aea/aea-config.yaml @@ -37,3 +37,10 @@ connection_private_key_paths: fetchai: fetchai_private_key.txt default_routing: {} dependencies: {} +--- +public_id: dummy_author/test_skill:0.1.0 +type: skill +models: + scaffold: + args: + recursive: {} \ No newline at end of file diff --git a/tests/data/dummy_aea/skills/test_skill/skill.yaml b/tests/data/dummy_aea/skills/test_skill/skill.yaml index c34fe31d5f..c7a83b340c 100644 --- a/tests/data/dummy_aea/skills/test_skill/skill.yaml +++ b/tests/data/dummy_aea/skills/test_skill/skill.yaml @@ -30,6 +30,7 @@ models: scaffold: args: foo: bar + recursive: {} class_name: MyModel dependencies: {} is_abstract: false diff --git a/tests/test_aea_builder.py b/tests/test_aea_builder.py index 1f0a086284..87838ba766 100644 --- a/tests/test_aea_builder.py +++ b/tests/test_aea_builder.py @@ -1052,16 +1052,14 @@ def test_dependency_tree_check(): dummy_aea_path = Path(CUR_PATH, "data", "dummy_aea") aea_config_file = dummy_aea_path / DEFAULT_AEA_CONFIG_FILE original_content = aea_config_file.read_text() - missing_dependencies = original_content.replace( - "- dummy_author/test_skill:0.1.0\n", "" - ) + missing_dependencies = original_content.replace("- dummy_author/dummy:0.1.0\n", "") aea_config_file.write_text(missing_dependencies) try: with pytest.raises( AEAEnforceError, match=re.escape( - "Following dependencies are present in the project but missing from the aea-config.yaml; {PackageId(skill, dummy_author/test_skill:0.1.0)}" + "Following dependencies are present in the project but missing from the aea-config.yaml; {PackageId(skill, dummy_author/dummy:0.1.0)}" ), ): AEABuilder.from_aea_project(dummy_aea_path) diff --git a/tests/test_cli/test_config.py b/tests/test_cli/test_config.py index d66c080afd..5e03126180 100644 --- a/tests/test_cli/test_config.py +++ b/tests/test_cli/test_config.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # ------------------------------------------------------------------------------ # -# Copyright 2021 Valory AG +# Copyright 2021-2022 Valory AG # Copyright 2018-2019 Fetch.AI Limited # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -676,8 +676,6 @@ def get_component_config_value(self) -> dict: def test_set_get_correct_path(self): """Test component value updated in agent config not in component config.""" agent_config = self.load_agent_config() - assert not agent_config.component_configurations - config_value = self.get_component_config_value() assert config_value == self.INITIAL_VALUE diff --git a/tests/test_configurations/test_base.py b/tests/test_configurations/test_base.py index 1a1ded4818..1374fc14cc 100644 --- a/tests/test_configurations/test_base.py +++ b/tests/test_configurations/test_base.py @@ -366,7 +366,7 @@ def test_all_components_id(self): def test_component_configurations_setter(self): """Test component configuration setter.""" - assert self.aea_config.component_configurations == {} + assert len(self.aea_config.component_configurations) == 1 new_component_configurations = { self.dummy_skill_component_id: self.new_dummy_skill_config } @@ -374,7 +374,7 @@ def test_component_configurations_setter(self): def test_component_configurations_setter_negative(self): """Test component configuration setter with wrong configurations.""" - assert self.aea_config.component_configurations == {} + assert len(self.aea_config.component_configurations) == 1 new_component_configurations = { self.dummy_skill_component_id: { "handlers": {"dummy": {"class_name": "SomeClass"}} diff --git a/tests/test_configurations/test_manager.py b/tests/test_configurations/test_manager.py index cbc2240d28..7fdfcf9d24 100644 --- a/tests/test_configurations/test_manager.py +++ b/tests/test_configurations/test_manager.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # ------------------------------------------------------------------------------ # -# Copyright 2021 Valory AG +# Copyright 2021-2022 Valory AG # Copyright 2018-2019 Fetch.AI Limited # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,6 +21,7 @@ import os from copy import deepcopy from pathlib import Path +from typing import Dict, cast from unittest.mock import mock_open, patch import pytest @@ -40,9 +41,7 @@ DUMMY_AEA = Path(ROOT_DIR) / "tests" / "data" / "dummy_aea" - -agent_config_data = yaml.safe_load( - """ +DUMMY_AEA_CONFIG = """ agent_name: Agent0 author: dummy_author version: 1.0.0 @@ -72,12 +71,23 @@ ethereum: ethereum_private_key.txt default_routing: {} """ -) + +DUMMY_SKILL_OVERRIDE = """ +public_id: dummy_author/test_skill:0.1.0 +type: skill +models: + scaffold: + args: + recursive: + hello: world +""" + +AGENT_CONFIG_DATA = yaml.safe_load(DUMMY_AEA_CONFIG) def test_envvars_applied(): """Test env vars replaced with values.""" - dct = deepcopy(agent_config_data) + dct = deepcopy(AGENT_CONFIG_DATA) with patch.object(AgentConfigManager, "_load_config_data", return_value=[dct]): os.environ["DISABLE_LOGS"] = "true" agent_config_manager = AgentConfigManager.load(".", substitude_env_vars=True) @@ -129,7 +139,7 @@ def test_envvars_applied(): ) # not applied - dct = deepcopy(agent_config_data) + dct = deepcopy(AGENT_CONFIG_DATA) with patch.object(AgentConfigManager, "_load_config_data", return_value=[dct]): os.environ["DISABLE_LOGS"] = "true" agent_config_manager = AgentConfigManager.load(".", substitude_env_vars=False) @@ -142,7 +152,7 @@ def test_envvars_applied(): @patch.object(AgentConfigManager, "get_overridables", return_value=[{}, {}]) def test_envvars_preserved(*mocks): """Test env vars not modified on config update.""" - dct = deepcopy(agent_config_data) + dct = deepcopy(AGENT_CONFIG_DATA) new_cosmos_key_value = "cosmons_key_updated" with patch.object(AgentConfigManager, "_load_config_data", return_value=[dct]): @@ -176,7 +186,7 @@ def test_envvars_preserved(*mocks): def test_agent_attribute_get_set(): """Test agent config manager get set variables.""" - dct = deepcopy(agent_config_data) + dct = deepcopy(AGENT_CONFIG_DATA) with patch.object(AgentConfigManager, "_load_config_data", return_value=[dct]): os.environ["DISABLE_LOGS"] = "true" agent_config_manager = AgentConfigManager.load( @@ -235,6 +245,48 @@ def test_agent_attribute_get_set(): agent_config_manager.verify_private_keys(DUMMY_AEA, lambda x, y, z: None) +def test_recursive_updates() -> None: + """Test recursive updates.""" + agent_config_manager = AgentConfigManager.load(DUMMY_AEA, substitude_env_vars=True) + agent_config_manager.set_variable( + "skills.test_skill.models.scaffold.args.recursive", {"foo": "bar"} + ) + value = cast( + Dict, + agent_config_manager.get_variable( + "skills.test_skill.models.scaffold.args.recursive" + ), + ) + + assert value == {"foo": "bar"} + + agent_config_manager.set_variable( + "skills.test_skill.models.scaffold.args.recursive", + {"hello": "world"}, + ) + value = cast( + Dict, + agent_config_manager.get_variable( + "skills.test_skill.models.scaffold.args.recursive" + ), + ) + + assert value == {"hello": "world"} + + agent_config_manager.set_variable( + "skills.test_skill.models.scaffold.args.recursive.hello", + "world_0", + ) + value = cast( + Dict, + agent_config_manager.get_variable( + "skills.test_skill.models.scaffold.args.recursive" + ), + ) + + assert value == {"hello": "world_0"} + + def test_agent_attribute_get_overridables(): """Test AgentConfigManager.get_overridables.""" agent_config_manager = AgentConfigManager.load(DUMMY_AEA, substitude_env_vars=False)