Skip to content

Commit

Permalink
fix: review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Tbaile committed Sep 29, 2023
1 parent d8b3987 commit 6befa0a
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 35 deletions.
54 changes: 24 additions & 30 deletions src/nethsec/mwan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,27 @@
import json
import subprocess

import uci
from euci import EUci

from nethsec import utils
from nethsec.utils import ValidationError


def __generate_metric(e_uci: EUci, interface_metrics: list[int] = None, metric: int = 1) -> int:
def __generate_metric(e_uci: EUci) -> int:
"""
Generates a metric for an interface.
Args:
e_uci: EUci instance
interface_metrics: list of metrics already used, will be generated if not provided
metric: metric to start from
Returns:
first metric that is not present in interface_metrics
"""
if interface_metrics is None:
interface_metrics = list[int]()
for interface in utils.get_all_by_type(e_uci, 'network', 'interface').values():
if 'metric' in interface:
interface_metrics.append(int(interface['metric']))

if metric not in interface_metrics:
return metric
else:
return __generate_metric(e_uci, interface_metrics, metric + 1)
next_metric = 0
for interface in utils.get_all_by_type(e_uci, 'network', 'interface').values():
if 'metric' in interface:
next_metric = max(next_metric, int(interface['metric']))
return next_metric + 1


def __store_interface(e_uci: EUci, name: str) -> tuple[bool, bool]:
Expand All @@ -52,16 +46,17 @@ def __store_interface(e_uci: EUci, name: str) -> tuple[bool, bool]:
ValidationError: if interface name is not defined in /etc/config/network
"""
# checking if interface is configured
available_interfaces = utils.get_all_by_type(e_uci, 'network', 'interface')
if name not in available_interfaces.keys():
try:
e_uci.get('network', name)
except uci.UciExceptionNotFound:
raise ValidationError('name', 'invalid', name)

created_interface = False
# if no interface with name exists, create one with defaults
if name not in utils.get_all_by_type(e_uci, 'mwan3', 'interface').keys():
if e_uci.get('mwan3', name, default=None) is None:
created_interface = True
# fetch default configuration and set interface
default_interface_config = utils.get_all_by_type(e_uci, 'ns-api', 'defaults_mwan').get('defaults_mwan')
default_interface_config = e_uci.get_all('ns-api', 'defaults_mwan')
e_uci.set('mwan3', name, 'interface')
e_uci.set('mwan3', name, 'enabled', '1')
e_uci.set('mwan3', name, 'initial_state', default_interface_config['initial_state'])
Expand All @@ -81,7 +76,7 @@ def __store_interface(e_uci: EUci, name: str) -> tuple[bool, bool]:

added_metric = False
# avoid adding metric if already present
if 'metric' not in available_interfaces[name]:
if e_uci.get('network', name, 'metric', default=None) is None:
added_metric = True
# generate metric
metric = __generate_metric(e_uci)
Expand All @@ -105,7 +100,7 @@ def __store_member(e_uci: EUci, interface_name: str, metric: int, weight: int) -
"""
member_config_name = utils.get_id(f'{interface_name}_M{metric}_W{weight}')
changed = False
if member_config_name not in utils.get_all_by_type(e_uci, 'mwan3', 'member').keys():
if e_uci.get('mwan3', member_config_name, default=None) is None:
changed = True
e_uci.set('mwan3', member_config_name, 'member')
e_uci.set('mwan3', member_config_name, 'interface', interface_name)
Expand Down Expand Up @@ -137,13 +132,12 @@ def store_rule(e_uci: EUci, name: str, policy: str, protocol: str = None,
"""
rule_config_name = utils.get_id(name.lower(), 15)
rules = utils.get_all_by_type(e_uci, 'mwan3', 'rule').keys()
if rule_config_name in e_uci.get('mwan3').keys():
if e_uci.get('mwan3', rule_config_name, default=None) is not None:
raise ValidationError('name', 'unique', name)
if policy not in utils.get_all_by_type(e_uci, 'mwan3', 'policy').keys():
if e_uci.get('mwan3', policy, default=None) is None:
raise ValidationError('policy', 'invalid', policy)
e_uci.set('mwan3', rule_config_name, 'rule')
e_uci.set('mwan3', rule_config_name, 'label', name)
e_uci.set('mwan3', rule_config_name, 'label', name)
e_uci.set('mwan3', rule_config_name, 'use_policy', policy)
if protocol is not None:
e_uci.set('mwan3', rule_config_name, 'proto', protocol)
Expand Down Expand Up @@ -179,7 +173,7 @@ def store_policy(e_uci: EUci, name: str, interfaces: list[dict]) -> list[str]:
# generate policy name
policy_config_name = utils.get_id(name.lower())
# make sure name is not something that already exists
if policy_config_name in e_uci.get('mwan3').keys():
if e_uci.get('mwan3', policy_config_name, default=None) is not None:
raise ValidationError('name', 'unique', name)
# generate policy config with corresponding name
e_uci.set('mwan3', policy_config_name, 'policy')
Expand All @@ -206,7 +200,7 @@ def __fetch_interface_status(interface_name: str) -> str:
'mwan3',
'status',
'{"section": "interfaces"}'
], capture_output=True)
], capture_output=True, check=True)
.stdout.decode('utf-8'))
decoded_output = json.JSONDecoder().decode(output)
return decoded_output['interfaces'][interface_name]['status']
Expand Down Expand Up @@ -306,7 +300,7 @@ def __add_interfaces(e_uci: EUci, interfaces: list[dict], changed_config: list[s


def edit_policy(e_uci: EUci, name: str, label: str, interfaces: list[dict]) -> list[str]:
if name not in utils.get_all_by_type(e_uci, 'mwan3', 'policy').keys():
if e_uci.get('mwan3', name, default=None) is None:
raise ValidationError('name', 'invalid', name)
changed_config = []
if label != e_uci.get_all('mwan3', name)['label']:
Expand All @@ -323,7 +317,7 @@ def edit_policy(e_uci: EUci, name: str, label: str, interfaces: list[dict]) -> l


def delete_policy(e_uci: EUci, name: str) -> list[str]:
if name not in utils.get_all_by_type(e_uci, 'mwan3', 'policy').keys():
if e_uci.get('mwan3', name, default=None) is None:
raise ValidationError('name', 'invalid', name)
e_uci.delete('mwan3', name)
e_uci.save('mwan3')
Expand All @@ -339,7 +333,7 @@ def index_rules(e_uci: EUci) -> list[dict]:
rule_data['name'] = rule_key
rule_data['policy'] = {}
rule_data['policy']['name'] = rule_value['use_policy']
if rule_value['use_policy'] in utils.get_all_by_type(e_uci, 'mwan3', 'policy').keys():
if e_uci.get('mwan3', rule_value['use_policy'], default=None) is not None:
rule_data['policy']['label'] = utils.get_all_by_type(e_uci, 'mwan3', 'policy')[rule_value['use_policy']]['label']
if 'label' in rule_value:
rule_data['label'] = rule_value['label']
Expand Down Expand Up @@ -387,7 +381,7 @@ def order_rules(e_uci: EUci, rules: list[str]) -> list[str]:


def delete_rule(e_uci: EUci, name: str):
if name not in utils.get_all_by_type(e_uci, 'mwan3', 'rule').keys():
if e_uci.get('mwan3', name, default=None) is None:
raise ValidationError('name', 'invalid', name)

e_uci.delete('mwan3', name)
Expand All @@ -398,10 +392,10 @@ def delete_rule(e_uci: EUci, name: str):
def edit_rule(e_uci: EUci, name: str, policy: str, label: str, protocol: str = None,
source_address: str = None, source_port: str = None,
destination_address: str = None, destination_port: str = None):
if name not in utils.get_all_by_type(e_uci, 'mwan3', 'rule').keys():
if e_uci.get('mwan3', name, default=None) is None:
raise ValidationError('name', 'invalid', name)

if policy not in utils.get_all_by_type(e_uci, 'mwan3', 'policy').keys():
if e_uci.get('mwan3', policy, default=None) is None:
raise ValidationError('policy', 'invalid', policy)
e_uci.set('mwan3', name, 'use_policy', policy)
e_uci.set('mwan3', name, 'label', label)
Expand Down
75 changes: 70 additions & 5 deletions tests/test_mwan.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,11 @@ def test_create_unique_mwan(e_uci, mocker):

def test_metric_generation(e_uci):
assert mwan.__generate_metric(e_uci) == 1
assert mwan.__generate_metric(e_uci, [1, 4]) == 2
assert mwan.__generate_metric(e_uci, [1, 2, 4]) == 3
assert mwan.__generate_metric(e_uci, [4, 3, 1]) == 2
assert mwan.__store_interface(e_uci, 'RED_1') == (True, True)
assert mwan.__generate_metric(e_uci) == 2
assert mwan.__generate_metric(e_uci) == 2
assert mwan.__store_interface(e_uci, 'RED_2') == (True, True)
assert mwan.__generate_metric(e_uci) == 3


def test_list_policies(e_uci, mocker):
Expand Down Expand Up @@ -424,5 +426,68 @@ def test_delete_rule(e_uci, mocker):
}
])
mwan.store_rule(e_uci, 'additional rule', 'ns_default')
mwan.delete_rule(e_uci, 'ns_additional_r')
assert 'ns_additional_r' not in e_uci.get_all('mwan3').keys()
assert mwan.delete_rule(e_uci, 'ns_additional_r') == 'mwan3.ns_additional_r'
assert 'ns_additional_r' not in e_uci.get_all('mwan3').keys()


def test_edit_rule(e_uci, mocker):
mocker.patch('subprocess.run')
mwan.store_policy(e_uci, 'hello world', [
{
'name': 'RED_1',
'metric': '10',
'weight': '100',
},
{
'name': 'RED_2',
'metric': '10',
'weight': '100',
}
])
mwan.store_policy(e_uci, 'cool policy', [
{
'name': 'RED_3',
'metric': '10',
'weight': '100',
},
{
'name': 'RED_1',
'metric': '10',
'weight': '100',
}
])
assert mwan.edit_rule(e_uci, 'ns_default_rule', 'ns_cool_policy', 'new label!', 'udp', '192.168.10.1/12', '80,443',
'0.0.0.0/0', '4040-8080') == 'mwan3.ns_default_rule'
assert e_uci.get('mwan3', 'ns_default_rule', 'label') == 'new label!'
assert e_uci.get('mwan3', 'ns_default_rule', 'use_policy') == 'ns_cool_policy'
assert e_uci.get('mwan3', 'ns_default_rule', 'proto') == 'udp'
assert e_uci.get('mwan3', 'ns_default_rule', 'src_ip') == '192.168.10.1/12'
assert e_uci.get('mwan3', 'ns_default_rule', 'src_port') == '80,443'
assert e_uci.get('mwan3', 'ns_default_rule', 'dest_ip') == '0.0.0.0/0'
assert e_uci.get('mwan3', 'ns_default_rule', 'dest_port') == '4040-8080'


def test_cant_edit_invalid_rule(e_uci, mocker):
mocker.patch('subprocess.run')
with pytest.raises(ValidationError) as e:
mwan.edit_rule(e_uci, 'ns_default_rule', 'ns_cool_policy', 'new label!')
assert e.value.args[0] == 'name'
assert e.value.args[1] == 'invalid'
assert e.value.args[2] == 'ns_default_rule'
mwan.store_policy(e_uci, 'hello world', [
{
'name': 'RED_1',
'metric': '10',
'weight': '100',
},
{
'name': 'RED_2',
'metric': '10',
'weight': '100',
}
])
with pytest.raises(ValidationError) as e:
mwan.edit_rule(e_uci, 'ns_default_rule', 'ns_cool_policy', 'new label!')
assert e.value.args[0] == 'policy'
assert e.value.args[1] == 'invalid'
assert e.value.args[2] == 'ns_cool_policy'

0 comments on commit 6befa0a

Please sign in to comment.