Skip to content

Commit

Permalink
Merge pull request #5 from Tbaile/zones-api
Browse files Browse the repository at this point in the history
Functions for zone/forwarding APIs
  • Loading branch information
gsanchietti authored Sep 15, 2023
2 parents d8bd50d + 1208dd3 commit 74d57b1
Show file tree
Hide file tree
Showing 2 changed files with 214 additions and 5 deletions.
124 changes: 122 additions & 2 deletions src/nethsec/firewall/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
'''
Firewall utilities
'''

import json
import subprocess

from nethsec import utils


def add_to_zone(uci, device, zone):
'''
Add given device to a firewall zone.
Expand Down Expand Up @@ -524,3 +524,123 @@ def disable_ipv6_firewall(uci):

uci.save("firewall")
return disabled


def list_zones(uci) -> dict:
"""
Get all zones from firewall config
Args:
uci: EUci pointer
Returns:
dict with all zones
"""
return utils.get_all_by_type(uci, 'firewall', 'zone')


def list_forwardings(uci) -> dict:
"""
Get all forwardings from firewall config
Args:
uci: EUci pointer
Returns:
dict with all forwardings
"""
return utils.get_all_by_type(uci, 'firewall', 'forwarding')


def add_forwarding(uci, src: str, dest: str) -> str:
"""
Add forwarding from src to dest.
Args:
uci: EUci pointer
src: source zone, must be zone name, not config name
dest: destination zone, must be zone name, not config name
Returns:
name of forwarding config that was added
"""
config_name = utils.get_id(f'{src}2{dest}')
uci.set('firewall', config_name, 'forwarding')
uci.set('firewall', config_name, 'src', src)
uci.set('firewall', config_name, 'dest', dest)
uci.save('firewall')
return config_name


def add_zone(uci, name: str, input: str, forward: str, traffic_to_wan: bool = False, forwards_to: list[str] = None,
forwards_from: list[str] = None) -> {str, set[str]}:
"""
Add zone to firewall config.
Args:
uci: EUci pointer
name: name of zone
input: rule for input traffic, must be one of 'ACCEPT', 'REJECT', 'DROP'
forward: rule for forward traffic, must be one of 'ACCEPT', 'REJECT', 'DROP'
traffic_to_wan: if True, add forwarding from zone to wan
forwards_to: list of zones to forward traffic to
forwards_from: list of zones to forward traffic from
Returns:
tuple of zone config name and set of added forwarding configs
"""
zone_config_name = utils.get_id(name)
uci.set('firewall', zone_config_name, 'zone')
uci.set('firewall', zone_config_name, 'name', name)
uci.set('firewall', zone_config_name, 'input', input)
uci.set('firewall', zone_config_name, 'forward', forward)
uci.set('firewall', zone_config_name, 'output', 'ACCEPT')

forwardings_added = set()

if traffic_to_wan:
forwardings_added.add(add_forwarding(uci, name, 'wan'))

if forwards_to is not None:
for forward_to in forwards_to:
forwardings_added.add(add_forwarding(uci, name, forward_to))

if forwards_from is not None:
for forward_from in forwards_from:
forwardings_added.add(add_forwarding(uci, forward_from, name))

uci.save('firewall')
return zone_config_name, forwardings_added


def delete_zone(uci, zone_config_name: str) -> {str, set[str]}:
"""
Delete zone and all forwardings that are connected to it.
Args:
uci: EUci pointer
zone_config_name: name of zone config to delete
Returns:
tuple of zone config name and set of deleted forwarding configs
Raises:
ValueError: if zone_config_name is not a valid zone config name
"""
if zone_config_name not in list_zones(uci):
raise ValueError
zone_name = list_zones(uci)[zone_config_name]['name']
forwardings = list_forwardings(uci)
to_delete_forwardings = set()
for forwarding in forwardings:
if forwardings[forwarding]['src'] == zone_name:
to_delete_forwardings.add(forwarding)
if forwardings[forwarding]['dest'] == zone_name:
to_delete_forwardings.add(forwarding)

for to_delete_forwarding in to_delete_forwardings:
uci.delete('firewall', to_delete_forwarding)

uci.delete('firewall', zone_config_name)
uci.save('firewall')
return zone_config_name, to_delete_forwardings
95 changes: 92 additions & 3 deletions tests/test_firewall.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from euci import EUci, UciExceptionNotFound
from euci import EUci, UciExceptionNotFound

from nethsec import firewall

firewall_db = """
Expand Down Expand Up @@ -136,10 +137,50 @@
option target 'ACCEPT'
"""

zone_testing_db = """
config zone 'ns_lan'
option name 'lan'
option input 'ACCEPT'
option output 'ACCEPT'
option forward 'ACCEPT'
list network 'GREEN_1'
config zone 'ns_wan'
option name 'wan'
option input 'REJECT'
option output 'ACCEPT'
option forward 'REJECT'
option masq '1'
option mtu_fix '1'
list network 'wan6'
list network 'RED_2'
list network 'RED_3'
list network 'RED_1'
config zone 'ns_guests'
option name 'guests'
option input 'DROP'
option forward 'DROP'
option output 'ACCEPT'
config forwarding
option src 'lan'
option dest 'wan'
config forwarding 'ns_guests2wan'
option src 'guests'
option dest 'wan'
config forwarding 'ns_lan2guests'
option src 'lan'
option dest 'guests'
"""

def _setup_db(tmp_path):
# setup fake dbs
with tmp_path.joinpath('firewall').open('w') as fp:
fp.write(firewall_db)
fp.write(zone_testing_db)
with tmp_path.joinpath('network').open('w') as fp:
fp.write(network_db)
with tmp_path.joinpath('templates').open('w') as fp:
Expand Down Expand Up @@ -316,13 +357,13 @@ def test_add_template_service_group(tmp_path):
assert u.get("firewall", sections[1], "proto") == "udp"
assert u.get("firewall", sections[1], "dest_port") == "53"
assert u.get("firewall", sections[1], "ns_tag") == "automated"

sections = firewall.add_template_service_group(u, "ns_web_secure", "grey", "orange")
assert u.get("firewall", sections[0], "src") == "grey"
assert u.get("firewall", sections[0], "dest") == "orange"
assert u.get("firewall", sections[0], "proto") == "tcp"
assert u.get("firewall", sections[1], "proto") == "udp"

sections = firewall.add_template_service_group(u, "ns_web_secure", "blue", "yellow", link="db/mykey")
assert u.get("firewall", sections[0], "ns_link") == "db/mykey"
assert u.get("firewall", sections[1], "ns_link") == "db/mykey"
Expand Down Expand Up @@ -401,3 +442,51 @@ def test_disable_ipv6_firewall(tmp_path):
assert u.get("firewall", "v6rule", "enabled", default="1") == "1"
firewall.disable_ipv6_firewall(u)
assert u.get("firewall", "v6rule", "enabled", default="1") == "0"


def test_list_zones(tmp_path):
u = _setup_db(tmp_path)
assert firewall.list_zones(u)["ns_lan"]["name"] == "lan"
assert firewall.list_zones(u)["ns_lan"]["input"] == "ACCEPT"
assert firewall.list_zones(u)["ns_lan"]["output"] == "ACCEPT"
assert firewall.list_zones(u)["ns_lan"]["forward"] == "ACCEPT"
assert firewall.list_zones(u)["ns_lan"]["network"] == ("GREEN_1",)
assert firewall.list_zones(u)["ns_wan"]["name"] == "wan"
assert firewall.list_zones(u)["ns_wan"]["input"] == "REJECT"
assert firewall.list_zones(u)["ns_wan"]["output"] == "ACCEPT"
assert firewall.list_zones(u)["ns_wan"]["forward"] == "REJECT"
assert firewall.list_zones(u)["ns_wan"]["network"] == ("wan6", "RED_2", "RED_3", "RED_1")


def test_list_forwardings(tmp_path):
u = _setup_db(tmp_path)
assert firewall.list_forwardings(u)["ns_lan2guests"]["src"] == "lan"
assert firewall.list_forwardings(u)["ns_lan2guests"]["dest"] == "guests"
assert firewall.list_forwardings(u)["ns_guests2wan"]["src"] == "guests"
assert firewall.list_forwardings(u)["ns_guests2wan"]["dest"] == "wan"


def test_add_zone(tmp_path):
u = _setup_db(tmp_path)
assert firewall.add_zone(u, "new_zone", "REJECT", "DROP", True, ["lan"], ["lan", "guest"]) == (
"ns_new_zone", {"ns_new_zone2wan", "ns_new_zone2lan", "ns_lan2new_zone", "ns_guest2new_zone"})
assert u.get("firewall", "ns_new_zone", "name") == "new_zone"
assert u.get("firewall", "ns_new_zone", "input") == "REJECT"
assert u.get("firewall", "ns_new_zone", "output") == "ACCEPT"
assert u.get("firewall", "ns_new_zone", "forward") == "DROP"
assert u.get("firewall", "ns_new_zone2wan", "src") == "new_zone"
assert u.get("firewall", "ns_new_zone2wan", "dest") == "wan"
assert u.get("firewall", "ns_new_zone2lan", "src") == "new_zone"
assert u.get("firewall", "ns_new_zone2lan", "dest") == "lan"
assert u.get("firewall", "ns_lan2new_zone", "src") == "lan"
assert u.get("firewall", "ns_lan2new_zone", "dest") == "new_zone"
assert u.get("firewall", "ns_guest2new_zone", "src") == "guest"
assert u.get("firewall", "ns_guest2new_zone", "dest") == "new_zone"


def test_delete_zone(tmp_path):
u = _setup_db(tmp_path)
assert firewall.delete_zone(u, "ns_new_zone") == (
"ns_new_zone", {"ns_new_zone2wan", "ns_new_zone2lan", "ns_guest2new_zone", "ns_lan2new_zone"})
with pytest.raises(Exception) as e:
firewall.delete_zone(u, "not_a_zone")

0 comments on commit 74d57b1

Please sign in to comment.