Skip to content

Commit

Permalink
fix: review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Tbaile committed Sep 29, 2023
1 parent 4d54578 commit 9e5fbd1
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 17 deletions.
31 changes: 16 additions & 15 deletions src/nethsec/mwan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import json
import subprocess

import uci
from euci import EUci

from nethsec import utils
Expand Down Expand Up @@ -52,16 +53,17 @@ def __store_interface(e_uci: EUci, name: str) -> tuple[bool, bool]:
ValidationError: if interface name is not defined in /etc/config/network
"""
# checking if interface is configured
available_interfaces = utils.get_all_by_type(e_uci, 'network', 'interface')
if name not in available_interfaces.keys():
try:
e_uci.get('network', name)
except uci.UciExceptionNotFound:
raise ValidationError('name', 'invalid', name)

created_interface = False
# if no interface with name exists, create one with defaults
if name not in utils.get_all_by_type(e_uci, 'mwan3', 'interface').keys():
if e_uci.get('mwan3', name, default=None) is None:
created_interface = True
# fetch default configuration and set interface
default_interface_config = utils.get_all_by_type(e_uci, 'ns-api', 'defaults_mwan').get('defaults_mwan')
default_interface_config = e_uci.get_all('ns-api', 'defaults_mwan')
e_uci.set('mwan3', name, 'interface')
e_uci.set('mwan3', name, 'enabled', '1')
e_uci.set('mwan3', name, 'initial_state', default_interface_config['initial_state'])
Expand All @@ -81,7 +83,7 @@ def __store_interface(e_uci: EUci, name: str) -> tuple[bool, bool]:

added_metric = False
# avoid adding metric if already present
if 'metric' not in available_interfaces[name]:
if 'metric' not in e_uci.get_all('network', name).keys():
added_metric = True
# generate metric
metric = __generate_metric(e_uci)
Expand All @@ -105,7 +107,7 @@ def __store_member(e_uci: EUci, interface_name: str, metric: int, weight: int) -
"""
member_config_name = utils.get_id(f'{interface_name}_M{metric}_W{weight}')
changed = False
if member_config_name not in utils.get_all_by_type(e_uci, 'mwan3', 'member').keys():
if e_uci.get('mwan3', member_config_name, default=None) is None:
changed = True
e_uci.set('mwan3', member_config_name, 'member')
e_uci.set('mwan3', member_config_name, 'interface', interface_name)
Expand Down Expand Up @@ -137,13 +139,12 @@ def store_rule(e_uci: EUci, name: str, policy: str, protocol: str = None,
"""
rule_config_name = utils.get_id(name.lower(), 15)
rules = utils.get_all_by_type(e_uci, 'mwan3', 'rule').keys()
if rule_config_name in e_uci.get('mwan3').keys():
if e_uci.get('mwan3', rule_config_name, default=None) is not None:
raise ValidationError('name', 'unique', name)
if policy not in utils.get_all_by_type(e_uci, 'mwan3', 'policy').keys():
if e_uci.get('mwan3', policy, default=None) is None:
raise ValidationError('policy', 'invalid', policy)
e_uci.set('mwan3', rule_config_name, 'rule')
e_uci.set('mwan3', rule_config_name, 'label', name)
e_uci.set('mwan3', rule_config_name, 'label', name)
e_uci.set('mwan3', rule_config_name, 'use_policy', policy)
if protocol is not None:
e_uci.set('mwan3', rule_config_name, 'proto', protocol)
Expand Down Expand Up @@ -179,7 +180,7 @@ def store_policy(e_uci: EUci, name: str, interfaces: list[dict]) -> list[str]:
# generate policy name
policy_config_name = utils.get_id(name.lower())
# make sure name is not something that already exists
if policy_config_name in e_uci.get('mwan3').keys():
if e_uci.get('mwan3', policy_config_name, default=None) is not None:
raise ValidationError('name', 'unique', name)
# generate policy config with corresponding name
e_uci.set('mwan3', policy_config_name, 'policy')
Expand All @@ -206,7 +207,7 @@ def __fetch_interface_status(interface_name: str) -> str:
'mwan3',
'status',
'{"section": "interfaces"}'
], capture_output=True)
], capture_output=True, check=True)
.stdout.decode('utf-8'))
decoded_output = json.JSONDecoder().decode(output)
return decoded_output['interfaces'][interface_name]['status']
Expand Down Expand Up @@ -306,7 +307,7 @@ def __add_interfaces(e_uci: EUci, interfaces: list[dict], changed_config: list[s


def edit_policy(e_uci: EUci, name: str, label: str, interfaces: list[dict]) -> list[str]:
if name not in utils.get_all_by_type(e_uci, 'mwan3', 'policy').keys():
if e_uci.get('mwan3', name, default=None) is None:
raise ValidationError('name', 'invalid', name)
changed_config = []
if label != e_uci.get_all('mwan3', name)['label']:
Expand All @@ -323,7 +324,7 @@ def edit_policy(e_uci: EUci, name: str, label: str, interfaces: list[dict]) -> l


def delete_policy(e_uci: EUci, name: str) -> list[str]:
if name not in utils.get_all_by_type(e_uci, 'mwan3', 'policy').keys():
if e_uci.get('mwan3', name, default=None) is None:
raise ValidationError('name', 'invalid', name)
e_uci.delete('mwan3', name)
e_uci.save('mwan3')
Expand All @@ -339,7 +340,7 @@ def index_rules(e_uci: EUci) -> list[dict]:
rule_data['name'] = rule_key
rule_data['policy'] = {}
rule_data['policy']['name'] = rule_value['use_policy']
if rule_value['use_policy'] in utils.get_all_by_type(e_uci, 'mwan3', 'policy').keys():
if e_uci.get('mwan3', rule_value['use_policy'], default=None) is not None:
rule_data['policy']['label'] = utils.get_all_by_type(e_uci, 'mwan3', 'policy')[rule_value['use_policy']]['label']
if 'label' in rule_value:
rule_data['label'] = rule_value['label']
Expand Down Expand Up @@ -387,7 +388,7 @@ def order_rules(e_uci: EUci, rules: list[str]) -> list[str]:


def delete_rule(e_uci: EUci, name: str):
if name not in utils.get_all_by_type(e_uci, 'mwan3', 'rule').keys():
if e_uci.get('mwan3', name, default=None) is None:
raise ValidationError('name', 'invalid', name)

e_uci.delete('mwan3', name)
Expand Down
67 changes: 65 additions & 2 deletions tests/test_mwan.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,5 +424,68 @@ def test_delete_rule(e_uci, mocker):
}
])
mwan.store_rule(e_uci, 'additional rule', 'ns_default')
mwan.delete_rule(e_uci, 'ns_additional_r')
assert 'ns_additional_r' not in e_uci.get_all('mwan3').keys()
assert mwan.delete_rule(e_uci, 'ns_additional_r') == 'mwan3.ns_additional_r'
assert 'ns_additional_r' not in e_uci.get_all('mwan3').keys()


def test_edit_rule(e_uci, mocker):
mocker.patch('subprocess.run')
mwan.store_policy(e_uci, 'hello world', [
{
'name': 'RED_1',
'metric': '10',
'weight': '100',
},
{
'name': 'RED_2',
'metric': '10',
'weight': '100',
}
])
mwan.store_policy(e_uci, 'cool policy', [
{
'name': 'RED_3',
'metric': '10',
'weight': '100',
},
{
'name': 'RED_1',
'metric': '10',
'weight': '100',
}
])
assert mwan.edit_rule(e_uci, 'ns_default_rule', 'ns_cool_policy', 'new label!', 'udp', '192.168.10.1/12', '80,443',
'0.0.0.0/0', '4040-8080') == 'mwan3.ns_default_rule'
assert e_uci.get('mwan3', 'ns_default_rule', 'label') == 'new label!'
assert e_uci.get('mwan3', 'ns_default_rule', 'use_policy') == 'ns_cool_policy'
assert e_uci.get('mwan3', 'ns_default_rule', 'proto') == 'udp'
assert e_uci.get('mwan3', 'ns_default_rule', 'src_ip') == '192.168.10.1/12'
assert e_uci.get('mwan3', 'ns_default_rule', 'src_port') == '80,443'
assert e_uci.get('mwan3', 'ns_default_rule', 'dest_ip') == '0.0.0.0/0'
assert e_uci.get('mwan3', 'ns_default_rule', 'dest_port') == '4040-8080'


def test_cant_edit_invalid_rule(e_uci, mocker):
mocker.patch('subprocess.run')
with pytest.raises(ValidationError) as e:
mwan.edit_rule(e_uci, 'ns_default_rule', 'ns_cool_policy', 'new label!')
assert e.value.args[0] == 'name'
assert e.value.args[1] == 'invalid'
assert e.value.args[2] == 'ns_default_rule'
mwan.store_policy(e_uci, 'hello world', [
{
'name': 'RED_1',
'metric': '10',
'weight': '100',
},
{
'name': 'RED_2',
'metric': '10',
'weight': '100',
}
])
with pytest.raises(ValidationError) as e:
mwan.edit_rule(e_uci, 'ns_default_rule', 'ns_cool_policy', 'new label!')
assert e.value.args[0] == 'policy'
assert e.value.args[1] == 'invalid'
assert e.value.args[2] == 'ns_cool_policy'

0 comments on commit 9e5fbd1

Please sign in to comment.