diff --git a/src/nethsec/mwan/__init__.py b/src/nethsec/mwan/__init__.py index db93117b..ccc0f987 100644 --- a/src/nethsec/mwan/__init__.py +++ b/src/nethsec/mwan/__init__.py @@ -8,33 +8,27 @@ import json import subprocess +import uci from euci import EUci from nethsec import utils from nethsec.utils import ValidationError -def __generate_metric(e_uci: EUci, interface_metrics: list[int] = None, metric: int = 1) -> int: +def __generate_metric(e_uci: EUci) -> int: """ Generates a metric for an interface. Args: e_uci: EUci instance - interface_metrics: list of metrics already used, will be generated if not provided - metric: metric to start from Returns: first metric that is not present in interface_metrics """ - if interface_metrics is None: - interface_metrics = list[int]() - for interface in utils.get_all_by_type(e_uci, 'network', 'interface').values(): - if 'metric' in interface: - interface_metrics.append(int(interface['metric'])) - - if metric not in interface_metrics: - return metric - else: - return __generate_metric(e_uci, interface_metrics, metric + 1) + next_metric = 0 + for interface in utils.get_all_by_type(e_uci, 'network', 'interface').values(): + if 'metric' in interface: + next_metric = max(next_metric, int(interface['metric'])) + return next_metric + 1 def __store_interface(e_uci: EUci, name: str) -> tuple[bool, bool]: @@ -52,16 +46,17 @@ def __store_interface(e_uci: EUci, name: str) -> tuple[bool, bool]: ValidationError: if interface name is not defined in /etc/config/network """ # checking if interface is configured - available_interfaces = utils.get_all_by_type(e_uci, 'network', 'interface') - if name not in available_interfaces.keys(): + try: + e_uci.get('network', name) + except uci.UciExceptionNotFound: raise ValidationError('name', 'invalid', name) created_interface = False # if no interface with name exists, create one with defaults - if name not in utils.get_all_by_type(e_uci, 'mwan3', 'interface').keys(): + if e_uci.get('mwan3', name, default=None) is None: created_interface = True # fetch default configuration and set interface - default_interface_config = utils.get_all_by_type(e_uci, 'ns-api', 'defaults_mwan').get('defaults_mwan') + default_interface_config = e_uci.get_all('ns-api', 'defaults_mwan') e_uci.set('mwan3', name, 'interface') e_uci.set('mwan3', name, 'enabled', '1') e_uci.set('mwan3', name, 'initial_state', default_interface_config['initial_state']) @@ -81,7 +76,7 @@ def __store_interface(e_uci: EUci, name: str) -> tuple[bool, bool]: added_metric = False # avoid adding metric if already present - if 'metric' not in available_interfaces[name]: + if e_uci.get('network', name, 'metric', default=None) is None: added_metric = True # generate metric metric = __generate_metric(e_uci) @@ -105,7 +100,7 @@ def __store_member(e_uci: EUci, interface_name: str, metric: int, weight: int) - """ member_config_name = utils.get_id(f'{interface_name}_M{metric}_W{weight}') changed = False - if member_config_name not in utils.get_all_by_type(e_uci, 'mwan3', 'member').keys(): + if e_uci.get('mwan3', member_config_name, default=None) is None: changed = True e_uci.set('mwan3', member_config_name, 'member') e_uci.set('mwan3', member_config_name, 'interface', interface_name) @@ -137,13 +132,12 @@ def store_rule(e_uci: EUci, name: str, policy: str, protocol: str = None, """ rule_config_name = utils.get_id(name.lower(), 15) rules = utils.get_all_by_type(e_uci, 'mwan3', 'rule').keys() - if rule_config_name in e_uci.get('mwan3').keys(): + if e_uci.get('mwan3', rule_config_name, default=None) is not None: raise ValidationError('name', 'unique', name) - if policy not in utils.get_all_by_type(e_uci, 'mwan3', 'policy').keys(): + if e_uci.get('mwan3', policy, default=None) is None: raise ValidationError('policy', 'invalid', policy) e_uci.set('mwan3', rule_config_name, 'rule') e_uci.set('mwan3', rule_config_name, 'label', name) - e_uci.set('mwan3', rule_config_name, 'label', name) e_uci.set('mwan3', rule_config_name, 'use_policy', policy) if protocol is not None: e_uci.set('mwan3', rule_config_name, 'proto', protocol) @@ -179,7 +173,7 @@ def store_policy(e_uci: EUci, name: str, interfaces: list[dict]) -> list[str]: # generate policy name policy_config_name = utils.get_id(name.lower()) # make sure name is not something that already exists - if policy_config_name in e_uci.get('mwan3').keys(): + if e_uci.get('mwan3', policy_config_name, default=None) is not None: raise ValidationError('name', 'unique', name) # generate policy config with corresponding name e_uci.set('mwan3', policy_config_name, 'policy') @@ -206,7 +200,7 @@ def __fetch_interface_status(interface_name: str) -> str: 'mwan3', 'status', '{"section": "interfaces"}' - ], capture_output=True) + ], capture_output=True, check=True) .stdout.decode('utf-8')) decoded_output = json.JSONDecoder().decode(output) return decoded_output['interfaces'][interface_name]['status'] @@ -306,7 +300,7 @@ def __add_interfaces(e_uci: EUci, interfaces: list[dict], changed_config: list[s def edit_policy(e_uci: EUci, name: str, label: str, interfaces: list[dict]) -> list[str]: - if name not in utils.get_all_by_type(e_uci, 'mwan3', 'policy').keys(): + if e_uci.get('mwan3', name, default=None) is None: raise ValidationError('name', 'invalid', name) changed_config = [] if label != e_uci.get_all('mwan3', name)['label']: @@ -323,7 +317,7 @@ def edit_policy(e_uci: EUci, name: str, label: str, interfaces: list[dict]) -> l def delete_policy(e_uci: EUci, name: str) -> list[str]: - if name not in utils.get_all_by_type(e_uci, 'mwan3', 'policy').keys(): + if e_uci.get('mwan3', name, default=None) is None: raise ValidationError('name', 'invalid', name) e_uci.delete('mwan3', name) e_uci.save('mwan3') @@ -339,7 +333,7 @@ def index_rules(e_uci: EUci) -> list[dict]: rule_data['name'] = rule_key rule_data['policy'] = {} rule_data['policy']['name'] = rule_value['use_policy'] - if rule_value['use_policy'] in utils.get_all_by_type(e_uci, 'mwan3', 'policy').keys(): + if e_uci.get('mwan3', rule_value['use_policy'], default=None) is not None: rule_data['policy']['label'] = utils.get_all_by_type(e_uci, 'mwan3', 'policy')[rule_value['use_policy']]['label'] if 'label' in rule_value: rule_data['label'] = rule_value['label'] @@ -387,7 +381,7 @@ def order_rules(e_uci: EUci, rules: list[str]) -> list[str]: def delete_rule(e_uci: EUci, name: str): - if name not in utils.get_all_by_type(e_uci, 'mwan3', 'rule').keys(): + if e_uci.get('mwan3', name, default=None) is None: raise ValidationError('name', 'invalid', name) e_uci.delete('mwan3', name) @@ -398,10 +392,10 @@ def delete_rule(e_uci: EUci, name: str): def edit_rule(e_uci: EUci, name: str, policy: str, label: str, protocol: str = None, source_address: str = None, source_port: str = None, destination_address: str = None, destination_port: str = None): - if name not in utils.get_all_by_type(e_uci, 'mwan3', 'rule').keys(): + if e_uci.get('mwan3', name, default=None) is None: raise ValidationError('name', 'invalid', name) - if policy not in utils.get_all_by_type(e_uci, 'mwan3', 'policy').keys(): + if e_uci.get('mwan3', policy, default=None) is None: raise ValidationError('policy', 'invalid', policy) e_uci.set('mwan3', name, 'use_policy', policy) e_uci.set('mwan3', name, 'label', label) diff --git a/tests/test_mwan.py b/tests/test_mwan.py index d593c747..283d36fc 100644 --- a/tests/test_mwan.py +++ b/tests/test_mwan.py @@ -175,9 +175,11 @@ def test_create_unique_mwan(e_uci, mocker): def test_metric_generation(e_uci): assert mwan.__generate_metric(e_uci) == 1 - assert mwan.__generate_metric(e_uci, [1, 4]) == 2 - assert mwan.__generate_metric(e_uci, [1, 2, 4]) == 3 - assert mwan.__generate_metric(e_uci, [4, 3, 1]) == 2 + assert mwan.__store_interface(e_uci, 'RED_1') == (True, True) + assert mwan.__generate_metric(e_uci) == 2 + assert mwan.__generate_metric(e_uci) == 2 + assert mwan.__store_interface(e_uci, 'RED_2') == (True, True) + assert mwan.__generate_metric(e_uci) == 3 def test_list_policies(e_uci, mocker): @@ -424,5 +426,68 @@ def test_delete_rule(e_uci, mocker): } ]) mwan.store_rule(e_uci, 'additional rule', 'ns_default') - mwan.delete_rule(e_uci, 'ns_additional_r') - assert 'ns_additional_r' not in e_uci.get_all('mwan3').keys() \ No newline at end of file + assert mwan.delete_rule(e_uci, 'ns_additional_r') == 'mwan3.ns_additional_r' + assert 'ns_additional_r' not in e_uci.get_all('mwan3').keys() + + +def test_edit_rule(e_uci, mocker): + mocker.patch('subprocess.run') + mwan.store_policy(e_uci, 'hello world', [ + { + 'name': 'RED_1', + 'metric': '10', + 'weight': '100', + }, + { + 'name': 'RED_2', + 'metric': '10', + 'weight': '100', + } + ]) + mwan.store_policy(e_uci, 'cool policy', [ + { + 'name': 'RED_3', + 'metric': '10', + 'weight': '100', + }, + { + 'name': 'RED_1', + 'metric': '10', + 'weight': '100', + } + ]) + assert mwan.edit_rule(e_uci, 'ns_default_rule', 'ns_cool_policy', 'new label!', 'udp', '192.168.10.1/12', '80,443', + '0.0.0.0/0', '4040-8080') == 'mwan3.ns_default_rule' + assert e_uci.get('mwan3', 'ns_default_rule', 'label') == 'new label!' + assert e_uci.get('mwan3', 'ns_default_rule', 'use_policy') == 'ns_cool_policy' + assert e_uci.get('mwan3', 'ns_default_rule', 'proto') == 'udp' + assert e_uci.get('mwan3', 'ns_default_rule', 'src_ip') == '192.168.10.1/12' + assert e_uci.get('mwan3', 'ns_default_rule', 'src_port') == '80,443' + assert e_uci.get('mwan3', 'ns_default_rule', 'dest_ip') == '0.0.0.0/0' + assert e_uci.get('mwan3', 'ns_default_rule', 'dest_port') == '4040-8080' + + +def test_cant_edit_invalid_rule(e_uci, mocker): + mocker.patch('subprocess.run') + with pytest.raises(ValidationError) as e: + mwan.edit_rule(e_uci, 'ns_default_rule', 'ns_cool_policy', 'new label!') + assert e.value.args[0] == 'name' + assert e.value.args[1] == 'invalid' + assert e.value.args[2] == 'ns_default_rule' + mwan.store_policy(e_uci, 'hello world', [ + { + 'name': 'RED_1', + 'metric': '10', + 'weight': '100', + }, + { + 'name': 'RED_2', + 'metric': '10', + 'weight': '100', + } + ]) + with pytest.raises(ValidationError) as e: + mwan.edit_rule(e_uci, 'ns_default_rule', 'ns_cool_policy', 'new label!') + assert e.value.args[0] == 'policy' + assert e.value.args[1] == 'invalid' + assert e.value.args[2] == 'ns_cool_policy'