From 43a1e638f9cd39747034e4aa6bdb390b247c791f Mon Sep 17 00:00:00 2001 From: badrogger Date: Fri, 15 Nov 2024 18:12:03 +0000 Subject: [PATCH 01/30] Add NFTablesRuleController --- Dockerfile | 7 +- core/schains/firewall/firewall_manager.py | 9 + core/schains/firewall/nftables.py | 288 ++++++++++++++++++++++ core/schains/firewall/rule_controller.py | 12 +- core/schains/firewall/utils.py | 37 ++- tests.Dockerfile | 6 +- tests/firewall/nftables_test.py | 105 ++++++++ 7 files changed, 458 insertions(+), 6 deletions(-) create mode 100644 core/schains/firewall/nftables.py create mode 100644 tests/firewall/nftables_test.py diff --git a/Dockerfile b/Dockerfile index 0a2c24bb3..540b256f6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ FROM python:3.11-bookworm -RUN apt-get update && apt-get install -y wget git libxslt-dev iptables kmod swig +RUN apt-get update && apt-get install -y wget git libxslt-dev iptables kmod swig nftables python3-nftables RUN mkdir /usr/src/admin WORKDIR /usr/src/admin @@ -8,12 +8,13 @@ WORKDIR /usr/src/admin COPY requirements.txt ./ COPY requirements-dev.txt ./ -RUN pip3 install --no-cache-dir -r requirements.txt +RUN pip3 install -r requirements.txt COPY . . RUN update-alternatives --set iptables /usr/sbin/iptables-legacy && \ update-alternatives --set ip6tables /usr/sbin/ip6tables-legacy -ENV PYTHONPATH="/usr/src/admin" +ENV PYTHONPATH="/usr/src/admin":/usr/lib/python3/dist-packages/ + ENV COLUMNS=80 diff --git a/core/schains/firewall/firewall_manager.py b/core/schains/firewall/firewall_manager.py index e2216fc73..43c5cb81d 100644 --- a/core/schains/firewall/firewall_manager.py +++ b/core/schains/firewall/firewall_manager.py @@ -22,6 +22,7 @@ from typing import Iterable, Optional from core.schains.firewall.iptables import IptablesController +from core.schains.firewall.nftables import NftablesController from core.schains.firewall.types import ( IFirewallManager, IHostFirewallController, @@ -88,3 +89,11 @@ def flush(self) -> None: class IptablesSChainFirewallManager(SChainFirewallManager): def create_host_controller(self) -> IptablesController: return IptablesController() + + +class NftSchainFirewallManager(SChainFirewallManager): + def create_host_controller(self) -> NftablesController: + nc_controller = NftablesController(chain=self.name) + nc_controller.create_table() + nc_controller.create_chain() + return nc_controller diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py new file mode 100644 index 000000000..32cc24405 --- /dev/null +++ b/core/schains/firewall/nftables.py @@ -0,0 +1,288 @@ +# -*- coding: utf-8 -*- +# +# This file is part of SKALE Admin +# +# Copyright (C) 2024 SKALE Labs +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + + +import logging +import importlib +import ipaddress +import multiprocessing +from functools import wraps +from typing import Callable, Iterable + +from core.schains.firewall.types import IHostFirewallController, SChainRule + +from typing import TypeVar +import json + +T = TypeVar('T') + + +logger = logging.getLogger(__name__) + +TABLE = 'filter' +CHAIN = 'INPUT' + + +def refreshed(func: Callable) -> Callable: + @wraps(func) + def wrapper(self, *args, **kwargs): + self.refresh() + return func(self, *args, **kwargs) + + return wrapper + + +def is_like_number(value): + if value is None: + return False + try: + int(value) + except ValueError: + return False + return True + + +class NftablesCmdFailedError(Exception): + pass + + +class NftablesController(IHostFirewallController): + plock = multiprocessing.Lock() + FAMILY = 'inet' + + def __init__(self, table: str = TABLE, chain: str = CHAIN) -> None: + self.table = table + self.chain = chain + self._nftables = importlib.import_module('nftables') + self.nft = self._nftables.Nftables() + self.nft.set_json_output(True) + + def _compose_json(self, commands: list[dict]) -> dict: + json_cmd = {'nftables': commands} + self.nft.json_validate(json_cmd) + return json_cmd + + def create_table(self) -> None: + if not self.has_table(self.table): + return self.run_cmd(f'add table inet {self.table}') + + def create_chain(self) -> None: + if not self.has_chain(self.chain): + return self.run_json_cmd( + self._compose_json( + [ + { + 'add': { + 'chain': { + 'family': self.FAMILY, + 'table': self.table, + 'name': self.chain, + 'hook': 'input', + } + } + } + ] + ) + ) + + @property + def chains(self) -> list[dict]: + output = self.run_cmd('list chains') + if output[0] != 0: + raise NftablesCmdFailedError(output) + parsed = json.loads(output[1])['nftables'] + return [record['chain']['name'] for record in parsed if 'chain' in record] + + @property + def tables(self) -> list[dict]: + output = self.run_cmd('list tables') + if output[0] != 0: + raise NftablesCmdFailedError(output) + parsed = json.loads(output[1])['nftables'] + return [record['table']['name'] for record in parsed if 'table' in record] + + def run_json_cmd(self, cmd: dict) -> tuple: + logger.debug('Nftables json cmd %s', cmd) + with self.plock: + return self.nft.json_cmd(cmd) + + def run_cmd(self, cmd: str) -> tuple: + logger.debug('Nftables cmd %s', cmd) + with self.plock: + return self.nft.cmd(cmd) + + def has_chain(self, chain: str) -> bool: + return chain in self.chains + + def has_table(self, table: str) -> bool: + return table in self.tables + + def add_rule(self, rule: SChainRule) -> None: + if self.has_rule(rule): + return + expr = self.rule_to_expr(rule) + + json_cmd = self._compose_json( + [ + { + 'add': { + 'rule': { + 'family': self.FAMILY, + 'table': self.table, + 'chain': self.chain, + 'expr': expr, + } + } + } + ] + ) + + rc, output, error = self.run_json_cmd(json_cmd) + if rc != 0: + raise NftablesCmdFailedError(f'Failed to add allow rule: {error}') + + @classmethod + def rule_to_expr(cls, rule: SChainRule) -> list: + expr = [] + + if rule.first_ip: + if rule.last_ip == rule.first_ip: + expr.append( + { + 'match': { + 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, + 'op': '==', + 'right': f'{rule.first_ip}', + } + } + ) + else: + expr.append( + { + 'match': { + 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, + 'op': '==', + 'right': {'range': [f'{rule.first_ip}', f'{rule.last_ip}']}, + } + } + ) + + if rule.port: + expr.append( + { + 'match': { + 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}}, + 'op': '==', + 'right': rule.port, + } + } + ) + + expr.append({'accept': None}) + return expr + + @classmethod + def expr_to_rule(self, expr: list) -> None: + port, first_ip, last_ip = None, None, None + for item in expr: + if 'match' in item: + match = item['match'] + + if match.get('left', {}).get('payload', {}).get('field') == 'dport': + port = match.get('right') + + if match.get('left', {}).get('payload', {}).get('field') == 'saddr': + right = match.get('right') + if isinstance(right, str): + first_ip = right + else: + first_ip, last_ip = right['range'] + + if any([port, first_ip, last_ip]): + return SChainRule(port=port, first_ip=first_ip, last_ip=last_ip) + + def remove_rule(self, rule: SChainRule) -> None: + if self.has_rule(rule): + expr = self.rule_to_expr(rule) + + output = None + rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}') + if rc != 0: + raise Exception(f'Failed to list rules: {error}') + + current_rules = json.loads(output) + + handle = None + for item in current_rules.get('nftables', []): + if 'rule' in item: + rule_data = item['rule'] + if rule_data.get('expr') == expr: + handle = rule_data.get('handle') + break + + if handle is None: + raise Exception('Rule not found') + + json_cmd = self._compose_json( + [ + { + 'delete': { + 'rule': { + 'family': self.FAMILY, + 'table': self.table, + 'chain': self.chain, + 'handle': handle, + } + } + } + ] + ) + + rc, output, error = self.run_json_cmd(json_cmd) + if rc != 0: + raise NftablesCmdFailedError(f'Failed to delete rule: {error}') + + @property # type: ignore + def rules(self) -> Iterable[SChainRule]: + output = None + rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}') + if output == '': + return [] + + data = json.loads(output) + rules = [] + + for item in data.get('nftables', []): + if 'rule' in item: + plain_rule = item['rule'] + rule = self.expr_to_rule(plain_rule.get('expr', [])) + if rule: + rules.append(rule) + return rules + + def has_rule(self, rule: SChainRule) -> bool: + return rule in self.rules + + @classmethod + def from_ip_network(cls, ip: str) -> str: + return str(ipaddress.ip_network(ip).hosts()[0]) + + @classmethod + def to_ip_network(cls, ip: str) -> str: + return str(ipaddress.ip_network(ip)) diff --git a/core/schains/firewall/rule_controller.py b/core/schains/firewall/rule_controller.py index 51e8920a8..08bfcd48d 100644 --- a/core/schains/firewall/rule_controller.py +++ b/core/schains/firewall/rule_controller.py @@ -23,7 +23,7 @@ from functools import wraps from typing import Any, Callable, cast, Dict, Iterable, List, Optional, TypeVar -from .firewall_manager import IptablesSChainFirewallManager +from .firewall_manager import IptablesSChainFirewallManager, NftSchainFirewallManager from .types import ( IFirewallManager, IpRange, @@ -214,3 +214,13 @@ def create_firewall_manager(self) -> IptablesSChainFirewallManager: self.base_port, # type: ignore self.base_port + self.ports_per_schain - 1 # type: ignore ) + + +class NftSchainRuleController(SChainRuleController): + @configured_only + def create_firewall_manager(self) -> NftSchainFirewallManager: + return NftSchainFirewallManager( + self.name, + self.base_port, # type: ignore + self.base_port + self.ports_per_schain - 1 # type: ignore + ) diff --git a/core/schains/firewall/utils.py b/core/schains/firewall/utils.py index 737361e18..7bb5ec006 100644 --- a/core/schains/firewall/utils.py +++ b/core/schains/firewall/utils.py @@ -25,7 +25,7 @@ from skale import Skale from .types import IpRange -from .rule_controller import IptablesSChainRuleController +from .rule_controller import IptablesSChainRuleController, NftSchainRuleController logger = logging.getLogger(__name__) @@ -37,6 +37,22 @@ def get_default_rule_controller( own_ip: Optional[str] = None, node_ips: List[str] = [], sync_agent_ranges: Optional[List[IpRange]] = [] +) -> IptablesSChainRuleController: + return get_nftables_rule_controller( + name=name, + base_port=base_port, + own_ip=own_ip, + node_ips=node_ips, + sync_agent_ranges=sync_agent_ranges + ) + + +def get_iptables_rule_controller( + name: str, + base_port: Optional[int] = None, + own_ip: Optional[str] = None, + node_ips: List[str] = [], + sync_agent_ranges: Optional[List[IpRange]] = [] ) -> IptablesSChainRuleController: sync_agent_ranges = sync_agent_ranges or [] logger.info('Creating rule controller for %s', name) @@ -50,6 +66,25 @@ def get_default_rule_controller( ) +def get_nftables_rule_controller( + name: str, + base_port: Optional[int] = None, + own_ip: Optional[str] = None, + node_ips: List[str] = [], + sync_agent_ranges: Optional[List[IpRange]] = [] +) -> NftSchainRuleController: + sync_agent_ranges = sync_agent_ranges or [] + logger.info('Creating rule controller for %s', name) + logger.debug('Rule controller ranges for %s: %s', name, sync_agent_ranges) + return NftSchainRuleController( + name=name, + base_port=base_port, + own_ip=own_ip, + node_ips=node_ips, + sync_ip_ranges=sync_agent_ranges + ) + + def get_sync_agent_ranges(skale: Skale) -> List[IpRange]: sync_agent_ranges = [] rnum = skale.sync_manager.get_ip_ranges_number() diff --git a/tests.Dockerfile b/tests.Dockerfile index b31db00ee..75b72f635 100644 --- a/tests.Dockerfile +++ b/tests.Dockerfile @@ -1,3 +1,7 @@ FROM admin:base -RUN pip3 install --no-cache-dir -r requirements-dev.txt +RUN apt update && apt install -y nftables python3-nftables + +RUN pip3 install -r requirements-dev.txt + +ENV PYTHONPATH=${PYTHONPATH}:/usr/lib/python3/dist-packages/ diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py new file mode 100644 index 000000000..038ec6cc8 --- /dev/null +++ b/tests/firewall/nftables_test.py @@ -0,0 +1,105 @@ +import concurrent.futures +import importlib +import subprocess +import time + +import pytest + +from core.schains.firewall.nftables import NftablesController +from core.schains.firewall.types import SChainRule + + +@pytest.fixture +def nf_test_tables(): + nft = importlib.import_module('nftables').Nftables() + nft.cmd('flush ruleset') + return nft + + +@pytest.fixture +def filter_table(nf_test_tables): + print(nf_test_tables.cmd('add table inet filter')) + + +@pytest.fixture +def custom_chain(nf_test_tables, filter_table): + nf_test_tables.cmd('add chain inet filter test-chain') + return 'test-chain' + + +def test_nftables_controller(custom_chain): + nft_controller = NftablesController(chain='test-chain') + rule_a = SChainRule(10000, '1.1.1.1', '2.2.2.2') + rule_b = SChainRule(10001, '3.3.3.3') + nft_controller.add_rule(rule_a) + nft_controller.add_rule(rule_b) + assert nft_controller.has_rule(rule_a) + assert nft_controller.has_rule(rule_b) + rules = list(nft_controller.rules) + assert rules == sorted([rule_b, rule_a]) + nft_controller.remove_rule(rule_a) + assert not nft_controller.has_rule(rule_a) + assert nft_controller.has_rule(rule_b) + nft_controller.remove_rule(rule_b) + assert not nft_controller.has_rule(rule_a) + + +def test_nftables_controller_duplicates(custom_chain): + rule_a = SChainRule(10000, '1.1.1.1', '2.2.2.2') + manager = NftablesController(chain='test-chain') + manager.add_rule(rule_a) + rule_b = SChainRule(10001, '3.3.3.3', '4.4.4.4') + manager.add_rule(rule_b) + assert sorted(list(manager.rules)) == sorted([ + SChainRule(port=10001, first_ip='3.3.3.3', last_ip='4.4.4.4'), + SChainRule(port=10000, first_ip='1.1.1.1', last_ip='2.2.2.2') + ]) + assert manager.has_rule(rule_b) + manager.add_rule(rule_b) + assert manager.has_rule(rule_b) + assert sorted(list(manager.rules)) == sorted([ + SChainRule(port=10001, first_ip='3.3.3.3', last_ip='4.4.4.4'), + SChainRule(port=10000, first_ip='1.1.1.1', last_ip='2.2.2.2') + ]) + manager.remove_rule(rule_b) + assert list(manager.rules) == [ + SChainRule(port=10000, first_ip='1.1.1.1', last_ip='2.2.2.2') + ] + + +def add_remove_rule(srule, refresh): + manager = NftablesController() + manager.add_rule(srule) + time.sleep(1) + if not manager.has_rule(srule): + return False + time.sleep(1) + manager.remove_rule(srule) + return True + + +def generate_srules(number=5): + return [ + SChainRule( + 10000 + 1, + f'{i}.{i}.{i}.{i}', f'{i + 1}.{i + 1}.{i + 1}.{i + 1}' + ) + for i in range(1, number * 2, 2) + ] + + +def test_nftables_manager_parallel(custom_chain): + srules = generate_srules(number=12) + + futures = [] + with concurrent.futures.ProcessPoolExecutor(max_workers=12) as executor: + futures = [ + executor.submit(add_remove_rule, srule) + for srule in srules + ] + + for future in concurrent.futures.as_completed(futures): + assert future.result + manager = NftablesController(custom_chain) + time.sleep(10) + assert len(list(manager.rules)) == 0 From 6980a49edf1d479da677b871e6fe64bf6188443a Mon Sep 17 00:00:00 2001 From: badrogger Date: Fri, 15 Nov 2024 18:38:11 +0000 Subject: [PATCH 02/30] Rename Nftables to NFTables --- core/schains/firewall/firewall_manager.py | 8 ++++---- core/schains/firewall/nftables.py | 18 +++++++++--------- core/schains/firewall/rule_controller.py | 8 ++++---- core/schains/firewall/utils.py | 6 +++--- tests/firewall/nftables_test.py | 13 ++++++------- 5 files changed, 26 insertions(+), 27 deletions(-) diff --git a/core/schains/firewall/firewall_manager.py b/core/schains/firewall/firewall_manager.py index 43c5cb81d..b43f3a223 100644 --- a/core/schains/firewall/firewall_manager.py +++ b/core/schains/firewall/firewall_manager.py @@ -22,7 +22,7 @@ from typing import Iterable, Optional from core.schains.firewall.iptables import IptablesController -from core.schains.firewall.nftables import NftablesController +from core.schains.firewall.nftables import NFTablesController from core.schains.firewall.types import ( IFirewallManager, IHostFirewallController, @@ -91,9 +91,9 @@ def create_host_controller(self) -> IptablesController: return IptablesController() -class NftSchainFirewallManager(SChainFirewallManager): - def create_host_controller(self) -> NftablesController: - nc_controller = NftablesController(chain=self.name) +class NFTSchainFirewallManager(SChainFirewallManager): + def create_host_controller(self) -> NFTablesController: + nc_controller = NFTablesController(chain=self.name) nc_controller.create_table() nc_controller.create_chain() return nc_controller diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 32cc24405..95e677826 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -58,11 +58,11 @@ def is_like_number(value): return True -class NftablesCmdFailedError(Exception): +class NFTablesCmdFailedError(Exception): pass -class NftablesController(IHostFirewallController): +class NFTablesController(IHostFirewallController): plock = multiprocessing.Lock() FAMILY = 'inet' @@ -70,7 +70,7 @@ def __init__(self, table: str = TABLE, chain: str = CHAIN) -> None: self.table = table self.chain = chain self._nftables = importlib.import_module('nftables') - self.nft = self._nftables.Nftables() + self.nft = self._nftables.NFTables() self.nft.set_json_output(True) def _compose_json(self, commands: list[dict]) -> dict: @@ -105,7 +105,7 @@ def create_chain(self) -> None: def chains(self) -> list[dict]: output = self.run_cmd('list chains') if output[0] != 0: - raise NftablesCmdFailedError(output) + raise NFTablesCmdFailedError(output) parsed = json.loads(output[1])['nftables'] return [record['chain']['name'] for record in parsed if 'chain' in record] @@ -113,17 +113,17 @@ def chains(self) -> list[dict]: def tables(self) -> list[dict]: output = self.run_cmd('list tables') if output[0] != 0: - raise NftablesCmdFailedError(output) + raise NFTablesCmdFailedError(output) parsed = json.loads(output[1])['nftables'] return [record['table']['name'] for record in parsed if 'table' in record] def run_json_cmd(self, cmd: dict) -> tuple: - logger.debug('Nftables json cmd %s', cmd) + logger.debug('NFTables json cmd %s', cmd) with self.plock: return self.nft.json_cmd(cmd) def run_cmd(self, cmd: str) -> tuple: - logger.debug('Nftables cmd %s', cmd) + logger.debug('NFTables cmd %s', cmd) with self.plock: return self.nft.cmd(cmd) @@ -155,7 +155,7 @@ def add_rule(self, rule: SChainRule) -> None: rc, output, error = self.run_json_cmd(json_cmd) if rc != 0: - raise NftablesCmdFailedError(f'Failed to add allow rule: {error}') + raise NFTablesCmdFailedError(f'Failed to add allow rule: {error}') @classmethod def rule_to_expr(cls, rule: SChainRule) -> list: @@ -256,7 +256,7 @@ def remove_rule(self, rule: SChainRule) -> None: rc, output, error = self.run_json_cmd(json_cmd) if rc != 0: - raise NftablesCmdFailedError(f'Failed to delete rule: {error}') + raise NFTablesCmdFailedError(f'Failed to delete rule: {error}') @property # type: ignore def rules(self) -> Iterable[SChainRule]: diff --git a/core/schains/firewall/rule_controller.py b/core/schains/firewall/rule_controller.py index 08bfcd48d..3b63026bb 100644 --- a/core/schains/firewall/rule_controller.py +++ b/core/schains/firewall/rule_controller.py @@ -23,7 +23,7 @@ from functools import wraps from typing import Any, Callable, cast, Dict, Iterable, List, Optional, TypeVar -from .firewall_manager import IptablesSChainFirewallManager, NftSchainFirewallManager +from .firewall_manager import IptablesSChainFirewallManager, NFTSchainFirewallManager from .types import ( IFirewallManager, IpRange, @@ -216,10 +216,10 @@ def create_firewall_manager(self) -> IptablesSChainFirewallManager: ) -class NftSchainRuleController(SChainRuleController): +class NFTSchainRuleController(SChainRuleController): @configured_only - def create_firewall_manager(self) -> NftSchainFirewallManager: - return NftSchainFirewallManager( + def create_firewall_manager(self) -> NFTSchainFirewallManager: + return NFTSchainFirewallManager( self.name, self.base_port, # type: ignore self.base_port + self.ports_per_schain - 1 # type: ignore diff --git a/core/schains/firewall/utils.py b/core/schains/firewall/utils.py index 7bb5ec006..1f94694fd 100644 --- a/core/schains/firewall/utils.py +++ b/core/schains/firewall/utils.py @@ -25,7 +25,7 @@ from skale import Skale from .types import IpRange -from .rule_controller import IptablesSChainRuleController, NftSchainRuleController +from .rule_controller import IptablesSChainRuleController, NFTSchainRuleController logger = logging.getLogger(__name__) @@ -72,11 +72,11 @@ def get_nftables_rule_controller( own_ip: Optional[str] = None, node_ips: List[str] = [], sync_agent_ranges: Optional[List[IpRange]] = [] -) -> NftSchainRuleController: +) -> NFTSchainRuleController: sync_agent_ranges = sync_agent_ranges or [] logger.info('Creating rule controller for %s', name) logger.debug('Rule controller ranges for %s: %s', name, sync_agent_ranges) - return NftSchainRuleController( + return NFTSchainRuleController( name=name, base_port=base_port, own_ip=own_ip, diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py index 038ec6cc8..2cfa02058 100644 --- a/tests/firewall/nftables_test.py +++ b/tests/firewall/nftables_test.py @@ -1,17 +1,16 @@ import concurrent.futures import importlib -import subprocess import time import pytest -from core.schains.firewall.nftables import NftablesController +from core.schains.firewall.nftables import NFTablesController from core.schains.firewall.types import SChainRule @pytest.fixture def nf_test_tables(): - nft = importlib.import_module('nftables').Nftables() + nft = importlib.import_module('nftables').NFTables() nft.cmd('flush ruleset') return nft @@ -28,7 +27,7 @@ def custom_chain(nf_test_tables, filter_table): def test_nftables_controller(custom_chain): - nft_controller = NftablesController(chain='test-chain') + nft_controller = NFTablesController(chain='test-chain') rule_a = SChainRule(10000, '1.1.1.1', '2.2.2.2') rule_b = SChainRule(10001, '3.3.3.3') nft_controller.add_rule(rule_a) @@ -46,7 +45,7 @@ def test_nftables_controller(custom_chain): def test_nftables_controller_duplicates(custom_chain): rule_a = SChainRule(10000, '1.1.1.1', '2.2.2.2') - manager = NftablesController(chain='test-chain') + manager = NFTablesController(chain='test-chain') manager.add_rule(rule_a) rule_b = SChainRule(10001, '3.3.3.3', '4.4.4.4') manager.add_rule(rule_b) @@ -68,7 +67,7 @@ def test_nftables_controller_duplicates(custom_chain): def add_remove_rule(srule, refresh): - manager = NftablesController() + manager = NFTablesController() manager.add_rule(srule) time.sleep(1) if not manager.has_rule(srule): @@ -100,6 +99,6 @@ def test_nftables_manager_parallel(custom_chain): for future in concurrent.futures.as_completed(futures): assert future.result - manager = NftablesController(custom_chain) + manager = NFTablesController(custom_chain) time.sleep(10) assert len(list(manager.rules)) == 0 From 17a9eeade8719dc62fd0badaebfd8c1c4741f68f Mon Sep 17 00:00:00 2001 From: badrogger Date: Mon, 18 Nov 2024 16:10:57 +0000 Subject: [PATCH 03/30] Fix tests --- tests/firewall/nftables_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py index 2cfa02058..d11a2d6ff 100644 --- a/tests/firewall/nftables_test.py +++ b/tests/firewall/nftables_test.py @@ -10,7 +10,7 @@ @pytest.fixture def nf_test_tables(): - nft = importlib.import_module('nftables').NFTables() + nft = importlib.import_module('nftables').Nftables() nft.cmd('flush ruleset') return nft From 204673a4038f0aebd9b95e23991812d31b35cab7 Mon Sep 17 00:00:00 2001 From: badrogger Date: Mon, 18 Nov 2024 19:09:47 +0000 Subject: [PATCH 04/30] Fix import --- core/schains/firewall/nftables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 95e677826..5c8808c16 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -70,7 +70,7 @@ def __init__(self, table: str = TABLE, chain: str = CHAIN) -> None: self.table = table self.chain = chain self._nftables = importlib.import_module('nftables') - self.nft = self._nftables.NFTables() + self.nft = self._nftables.Nftables() self.nft.set_json_output(True) def _compose_json(self, commands: list[dict]) -> dict: From e9cb5064f461a8f4c52667424fbf330cd9462a62 Mon Sep 17 00:00:00 2001 From: badrogger Date: Fri, 22 Nov 2024 15:46:27 +0000 Subject: [PATCH 05/30] Fix default_rule_controller_test --- core/schains/firewall/__init__.py | 1 + .../firewall/default_rule_controller_test.py | 67 +++++-------------- 2 files changed, 17 insertions(+), 51 deletions(-) diff --git a/core/schains/firewall/__init__.py b/core/schains/firewall/__init__.py index 8edbd1a7c..1bba60b76 100644 --- a/core/schains/firewall/__init__.py +++ b/core/schains/firewall/__init__.py @@ -19,6 +19,7 @@ from .firewall_manager import SChainFirewallManager # noqa from .iptables import IptablesController # noqa +from .nftables import NFTablesController # noqa from .rule_controller import SChainRuleController # noqa from .types import IRuleController # noqa from .utils import get_default_rule_controller # noqa diff --git a/tests/firewall/default_rule_controller_test.py b/tests/firewall/default_rule_controller_test.py index c2473e16f..2032b918c 100644 --- a/tests/firewall/default_rule_controller_test.py +++ b/tests/firewall/default_rule_controller_test.py @@ -1,23 +1,13 @@ + import mock import pytest import concurrent.futures from skale.schain_config import PORTS_PER_SCHAIN # noqa -from core.schains.firewall import IptablesController +from core.schains.firewall import NFTablesController from core.schains.firewall.utils import get_default_rule_controller from core.schains.firewall.types import IpRange, SkaledPorts -from tests.firewall.iptables_test import get_rules_through_subprocess -from tools.helper import run_cmd - - -@pytest.fixture -def refresh(): - run_cmd(['iptables', '-F']) - try: - yield - finally: - run_cmd(['iptables', '-F']) def test_get_default_rule_controller(): @@ -58,22 +48,6 @@ def sync_rules(*args): return False -def parse_plain_rule(plain_rule): - first_ip, last_ip, port = None, None, None - pr = plain_rule.split() - if '--src-range' in pr: - srange = pr[11] - first_ip, last_ip = srange.split('-') - port = pr[7] - elif '-s' in pr: - first_ip = last_ip = pr[3][:-3] - port = pr[9] - elif '--dport' in pr: - port = pr[7] - - return first_ip, last_ip, int(port) - - def run_concurrent_rc_syncing( node_number, schain_number, @@ -147,55 +121,46 @@ def run_concurrent_rc_syncing( else: assert not r - pr = get_rules_through_subprocess(unique=False)[3:] - rules = [parse_plain_rule(r) for r in pr] + controllers = [NFTablesController(chain=name) for name in schain_names] + rules = [] + for controller in controllers: + rules.extend(controller.rules) + + print([r.port for r in rules]) + print([r.first_ip for r in rules]) - c = IptablesController() # Check that all ip rules are there for ip in node_ips: if ip != own_ip: assert sum( - map(lambda x: x[0] == ip, rules) - ) == 5 * schain_number, ip - assert sum( - map(lambda x: x.first_ip == ip, c.rules) + map(lambda x: x.first_ip == ip, rules) ) == 5 * schain_number, ip # Check that all internal ports rules are there except CATCHUP for p in internal_ports: - assert sum(map(lambda x: x[2] == p, rules)) == node_number - 1, p - assert sum(map(lambda x: x.port == p, c.rules)) == node_number - 1, p + assert sum(map(lambda x: x.port == p, rules)) == node_number - 1, p # Check CATCHUP rules including sync agents rules catchup_e_number = node_number + sync_agent_ranges_number - 1 for p in catchup_ports: - assert sum(map(lambda x: x[2] == p, rules)) == catchup_e_number, p - assert sum(map(lambda x: x.port == p, c.rules)) == catchup_e_number, p + assert sum(map(lambda x: x.port == p, rules)) == catchup_e_number, p # Check ZMQ rules including sync agents rules zmq_e_number = node_number + sync_agent_ranges_number - 1 for p in zmq_ports: - assert sum(map(lambda x: x[2] == p, rules)) == zmq_e_number, p - assert sum(map(lambda x: x.port == p, c.rules)) == zmq_e_number, p + assert sum(map(lambda x: x.port == p, rules)) == zmq_e_number, p # Check sync ip ranges rules for r in sync_agent_ranges: assert sum( - map(lambda x: x[0] == r.start_ip, rules) - ) == schain_number * 2, ip - assert sum( - map(lambda x: x.first_ip == r.start_ip, c.rules) - ) == schain_number * 2, ip - assert sum( - map(lambda x: x[1] == r.end_ip, rules) + map(lambda x: x.first_ip == r.start_ip, rules) ) == schain_number * 2, ip assert sum( - map(lambda x: x.last_ip == r.end_ip, c.rules) + map(lambda x: x.last_ip == r.end_ip, rules) ) == schain_number * 2, ip for port in public_ports: - assert sum(map(lambda x: x[2] == port, rules)) == 1, port - assert sum(map(lambda x: x.port == port, c.rules)) == 1, port + assert sum(map(lambda x: x.port == port, rules)) == 1, port @pytest.mark.parametrize('attempt', range(5)) From 8e40cc4c5018665a4d42f43fffaf91aadd576242 Mon Sep 17 00:00:00 2001 From: badrogger Date: Mon, 25 Nov 2024 13:33:29 +0000 Subject: [PATCH 06/30] Fix default_rule_controller test --- tests/firewall/default_rule_controller_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/firewall/default_rule_controller_test.py b/tests/firewall/default_rule_controller_test.py index 2032b918c..bd29fb40d 100644 --- a/tests/firewall/default_rule_controller_test.py +++ b/tests/firewall/default_rule_controller_test.py @@ -180,7 +180,7 @@ def test_concurrent_rc_behavior_no_refresh(attempt): @pytest.mark.parametrize('attempt', range(5)) -def test_concurrent_rc_behavior_with_refresh(attempt, refresh): +def test_concurrent_rc_behavior_with_refresh(attempt): node_number = 16 schain_number = 8 own_ip = '1.1.1.1' From b8d57817004a85b1909b0fa48ae3479bc88e48a7 Mon Sep 17 00:00:00 2001 From: badrogger Date: Mon, 25 Nov 2024 15:09:24 +0000 Subject: [PATCH 07/30] Fix test_concurrent_rc_behavior_with_refresh --- tests/firewall/default_rule_controller_test.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/firewall/default_rule_controller_test.py b/tests/firewall/default_rule_controller_test.py index bd29fb40d..0eb4668a3 100644 --- a/tests/firewall/default_rule_controller_test.py +++ b/tests/firewall/default_rule_controller_test.py @@ -9,6 +9,17 @@ from core.schains.firewall.utils import get_default_rule_controller from core.schains.firewall.types import IpRange, SkaledPorts +from tools.helper import run_cmd + + +@pytest.fixture +def refresh(): + run_cmd(['nft', 'flush', 'ruleset']) + try: + yield + finally: + run_cmd(['nft', 'flush', 'ruleset']) + def test_get_default_rule_controller(): own_ip = '3.3.3.3' @@ -180,7 +191,7 @@ def test_concurrent_rc_behavior_no_refresh(attempt): @pytest.mark.parametrize('attempt', range(5)) -def test_concurrent_rc_behavior_with_refresh(attempt): +def test_concurrent_rc_behavior_with_refresh(attempt, refresh): node_number = 16 schain_number = 8 own_ip = '1.1.1.1' From 50cd4ce53d961c5f3a04302636600a32d82a74f1 Mon Sep 17 00:00:00 2001 From: badrogger Date: Mon, 25 Nov 2024 18:35:10 +0000 Subject: [PATCH 08/30] Do not raise Exception in nftables --- core/schains/firewall/nftables.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 5c8808c16..753015823 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -224,7 +224,7 @@ def remove_rule(self, rule: SChainRule) -> None: output = None rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}') if rc != 0: - raise Exception(f'Failed to list rules: {error}') + raise NFTablesCmdFailedError(f'Failed to list rules: {error}') current_rules = json.loads(output) @@ -237,7 +237,7 @@ def remove_rule(self, rule: SChainRule) -> None: break if handle is None: - raise Exception('Rule not found') + raise NFTablesCmdFailedError('Rule not found') json_cmd = self._compose_json( [ From 3ff53f841d4d0f8bca71ee4ce3d3a7a77b201108 Mon Sep 17 00:00:00 2001 From: badrogger Date: Fri, 13 Dec 2024 20:44:54 +0000 Subject: [PATCH 09/30] Fix chain creation --- core/schains/firewall/firewall_manager.py | 2 +- core/schains/firewall/nftables.py | 101 +++++++++++++++------- 2 files changed, 70 insertions(+), 33 deletions(-) diff --git a/core/schains/firewall/firewall_manager.py b/core/schains/firewall/firewall_manager.py index b43f3a223..d16f71574 100644 --- a/core/schains/firewall/firewall_manager.py +++ b/core/schains/firewall/firewall_manager.py @@ -95,5 +95,5 @@ class NFTSchainFirewallManager(SChainFirewallManager): def create_host_controller(self) -> NFTablesController: nc_controller = NFTablesController(chain=self.name) nc_controller.create_table() - nc_controller.create_chain() + nc_controller.create_chain(self.first_port, self.last_port) return nc_controller diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 753015823..36d98543b 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -22,8 +22,7 @@ import importlib import ipaddress import multiprocessing -from functools import wraps -from typing import Callable, Iterable +from typing import Iterable from core.schains.firewall.types import IHostFirewallController, SChainRule @@ -35,29 +34,10 @@ logger = logging.getLogger(__name__) -TABLE = 'filter' +TABLE = 'firewall' CHAIN = 'INPUT' -def refreshed(func: Callable) -> Callable: - @wraps(func) - def wrapper(self, *args, **kwargs): - self.refresh() - return func(self, *args, **kwargs) - - return wrapper - - -def is_like_number(value): - if value is None: - return False - try: - int(value) - except ValueError: - return False - return True - - class NFTablesCmdFailedError(Exception): pass @@ -82,7 +62,43 @@ def create_table(self) -> None: if not self.has_table(self.table): return self.run_cmd(f'add table inet {self.table}') - def create_chain(self) -> None: + def add_schain_drop_rule(self, first_port: int, last_port: int) -> None: + expr = [ + { + "match": { + "left": { + "payload": { + "protocol": "tcp", + "field": "dport" + } + }, + "op": "==", + "right": {'range': [first_port, last_port]} + } + }, + {'counter': None}, + {"drop": None} + ] + + if self.expr_to_rule(expr) not in self.get_rules_by_policy(policy='drop'): + cmd = { + 'nftables': [ + { + 'add': { + 'rule': { + 'family': self.FAMILY, + 'table': self.table, + 'chain': self.chain, + 'expr': expr, + } + } + } + ] + } + self.run_json_cmd(cmd) + logger.info('Added drop rule for chain %s', self.chain) + + def create_chain(self, first_port: int, last_port: int) -> None: if not self.has_chain(self.chain): return self.run_json_cmd( self._compose_json( @@ -94,12 +110,16 @@ def create_chain(self) -> None: 'table': self.table, 'name': self.chain, 'hook': 'input', + 'type': 'filter', + 'prio': 0, + 'policy': 'accept', } } } ] ) ) + self.add_schain_drop_rule(first_port, last_port) @property def chains(self) -> list[dict]: @@ -141,7 +161,7 @@ def add_rule(self, rule: SChainRule) -> None: json_cmd = self._compose_json( [ { - 'add': { + 'insert': { 'rule': { 'family': self.FAMILY, 'table': self.table, @@ -194,7 +214,7 @@ def rule_to_expr(cls, rule: SChainRule) -> list: } ) - expr.append({'accept': None}) + expr.extend([{'counter': None}, {'accept': None}]) return expr @classmethod @@ -217,6 +237,13 @@ def expr_to_rule(self, expr: list) -> None: if any([port, first_ip, last_ip]): return SChainRule(port=port, first_ip=first_ip, last_ip=last_ip) + @classmethod + def expr_equals(cls, expr_a: list[dict], expr_b: list[dict]) -> bool: + for item_a, item_b in zip(sorted(expr_a), sorted(expr_b)): + if 'counter' not in item_a and item_a != item_b: + return False + return True + def remove_rule(self, rule: SChainRule) -> None: if self.has_rule(rule): expr = self.rule_to_expr(rule) @@ -228,11 +255,15 @@ def remove_rule(self, rule: SChainRule) -> None: current_rules = json.loads(output) + logger.info('HERE HERE %s', expr) + logger.info('HERE current rules %s', current_rules) handle = None for item in current_rules.get('nftables', []): if 'rule' in item: rule_data = item['rule'] - if rule_data.get('expr') == expr: + logger.info('HERE HERE 2 %s', rule_data['expr']) + logger.info('HERE HERE 3 %s', expr) + if self.expr_equals(rule_data.get('expr'), expr): handle = rule_data.get('handle') break @@ -260,6 +291,12 @@ def remove_rule(self, rule: SChainRule) -> None: @property # type: ignore def rules(self) -> Iterable[SChainRule]: + return self.get_rules_by_policy(policy='accept') + + def has_rule(self, rule: SChainRule) -> bool: + return rule in self.rules + + def get_rules_by_policy(self, policy: str) -> list[SChainRule]: output = None rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}') if output == '': @@ -271,14 +308,14 @@ def rules(self) -> Iterable[SChainRule]: for item in data.get('nftables', []): if 'rule' in item: plain_rule = item['rule'] - rule = self.expr_to_rule(plain_rule.get('expr', [])) - if rule: - rules.append(rule) + expr = plain_rule.get('expr', []) + if {policy: None} in expr: + rule = self.expr_to_rule(expr) + if rule: + rules.append(rule) + logger.debug('Rules for policy %s: %s', policy, rules) return rules - def has_rule(self, rule: SChainRule) -> bool: - return rule in self.rules - @classmethod def from_ip_network(cls, ip: str) -> str: return str(ipaddress.ip_network(ip).hosts()[0]) From 1b8e97050f065c5abe427487c1d9b27c57ab81a6 Mon Sep 17 00:00:00 2001 From: badrogger Date: Mon, 16 Dec 2024 13:36:08 +0000 Subject: [PATCH 10/30] Fix tests --- core/schains/firewall/nftables.py | 52 +++++++++++++------------------ tests/firewall/nftables_test.py | 6 ++-- 2 files changed, 24 insertions(+), 34 deletions(-) diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 36d98543b..0404a0f21 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -64,20 +64,15 @@ def create_table(self) -> None: def add_schain_drop_rule(self, first_port: int, last_port: int) -> None: expr = [ - { - "match": { - "left": { - "payload": { - "protocol": "tcp", - "field": "dport" + { + 'match': { + 'op': '==', + 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}}, + 'right': {'range': [first_port, last_port]}, } - }, - "op": "==", - "right": {'range': [first_port, last_port]} - } - }, - {'counter': None}, - {"drop": None} + }, + {'counter': None}, + {'drop': None}, ] if self.expr_to_rule(expr) not in self.get_rules_by_policy(policy='drop'): @@ -178,7 +173,7 @@ def add_rule(self, rule: SChainRule) -> None: raise NFTablesCmdFailedError(f'Failed to add allow rule: {error}') @classmethod - def rule_to_expr(cls, rule: SChainRule) -> list: + def rule_to_expr(cls, rule: SChainRule, counter: bool = True) -> list: expr = [] if rule.first_ip: @@ -186,8 +181,8 @@ def rule_to_expr(cls, rule: SChainRule) -> list: expr.append( { 'match': { - 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, 'op': '==', + 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, 'right': f'{rule.first_ip}', } } @@ -196,8 +191,8 @@ def rule_to_expr(cls, rule: SChainRule) -> list: expr.append( { 'match': { - 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, 'op': '==', + 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, 'right': {'range': [f'{rule.first_ip}', f'{rule.last_ip}']}, } } @@ -207,14 +202,17 @@ def rule_to_expr(cls, rule: SChainRule) -> list: expr.append( { 'match': { - 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}}, 'op': '==', + 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}}, 'right': rule.port, } } ) - expr.extend([{'counter': None}, {'accept': None}]) + if counter: + expr.append({'counter': None}) + + expr.append({'accept': None}) return expr @classmethod @@ -237,16 +235,9 @@ def expr_to_rule(self, expr: list) -> None: if any([port, first_ip, last_ip]): return SChainRule(port=port, first_ip=first_ip, last_ip=last_ip) - @classmethod - def expr_equals(cls, expr_a: list[dict], expr_b: list[dict]) -> bool: - for item_a, item_b in zip(sorted(expr_a), sorted(expr_b)): - if 'counter' not in item_a and item_a != item_b: - return False - return True - def remove_rule(self, rule: SChainRule) -> None: if self.has_rule(rule): - expr = self.rule_to_expr(rule) + expr = self.rule_to_expr(rule, counter=False) output = None rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}') @@ -255,15 +246,14 @@ def remove_rule(self, rule: SChainRule) -> None: current_rules = json.loads(output) - logger.info('HERE HERE %s', expr) - logger.info('HERE current rules %s', current_rules) handle = None for item in current_rules.get('nftables', []): if 'rule' in item: rule_data = item['rule'] - logger.info('HERE HERE 2 %s', rule_data['expr']) - logger.info('HERE HERE 3 %s', expr) - if self.expr_equals(rule_data.get('expr'), expr): + rule_expr = list( + filter(lambda statement: 'counter' not in statement, rule_data['expr']) + ) + if expr == rule_expr: handle = rule_data.get('handle') break diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py index d11a2d6ff..06dc1d52f 100644 --- a/tests/firewall/nftables_test.py +++ b/tests/firewall/nftables_test.py @@ -17,12 +17,12 @@ def nf_test_tables(): @pytest.fixture def filter_table(nf_test_tables): - print(nf_test_tables.cmd('add table inet filter')) + print(nf_test_tables.cmd('add table inet firewall')) @pytest.fixture def custom_chain(nf_test_tables, filter_table): - nf_test_tables.cmd('add chain inet filter test-chain') + nf_test_tables.cmd('add chain inet firewall test-chain') return 'test-chain' @@ -35,7 +35,7 @@ def test_nftables_controller(custom_chain): assert nft_controller.has_rule(rule_a) assert nft_controller.has_rule(rule_b) rules = list(nft_controller.rules) - assert rules == sorted([rule_b, rule_a]) + assert sorted(rules) == sorted([rule_b, rule_a]), (rules, sorted([rule_b, rule_a])) nft_controller.remove_rule(rule_a) assert not nft_controller.has_rule(rule_a) assert nft_controller.has_rule(rule_b) From 5c16c2d8d9b2f0f450ac90e5402c0333849c256c Mon Sep 17 00:00:00 2001 From: badrogger Date: Mon, 16 Dec 2024 19:56:32 +0000 Subject: [PATCH 11/30] Fix chain creation --- core/schains/firewall/nftables.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 0404a0f21..669bd0273 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -34,8 +34,9 @@ logger = logging.getLogger(__name__) + TABLE = 'firewall' -CHAIN = 'INPUT' +CHAIN = 'skale' class NFTablesCmdFailedError(Exception): @@ -48,7 +49,7 @@ class NFTablesController(IHostFirewallController): def __init__(self, table: str = TABLE, chain: str = CHAIN) -> None: self.table = table - self.chain = chain + self.chain = f'skale-{chain}' self._nftables = importlib.import_module('nftables') self.nft = self._nftables.Nftables() self.nft.set_json_output(True) @@ -95,7 +96,8 @@ def add_schain_drop_rule(self, first_port: int, last_port: int) -> None: def create_chain(self, first_port: int, last_port: int) -> None: if not self.has_chain(self.chain): - return self.run_json_cmd( + logger.info('Creating chain %s', self.chain) + self.run_json_cmd( self._compose_json( [ { From 319b806e7bb70e9e23c579707231b50750423ca1 Mon Sep 17 00:00:00 2001 From: badrogger Date: Tue, 17 Dec 2024 12:10:24 +0000 Subject: [PATCH 12/30] Fix nftables tests --- tests/firewall/nftables_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py index 06dc1d52f..0ee564ec0 100644 --- a/tests/firewall/nftables_test.py +++ b/tests/firewall/nftables_test.py @@ -22,8 +22,9 @@ def filter_table(nf_test_tables): @pytest.fixture def custom_chain(nf_test_tables, filter_table): - nf_test_tables.cmd('add chain inet firewall test-chain') - return 'test-chain' + name = 'test-chain' + nf_test_tables.cmd('add chain inet firewall skale-{name}') + return name def test_nftables_controller(custom_chain): From d7d09c10d00421fe7af62ef1427488abbd0e78b6 Mon Sep 17 00:00:00 2001 From: badrogger Date: Tue, 17 Dec 2024 12:31:16 +0000 Subject: [PATCH 13/30] Fix nftables test --- tests/firewall/nftables_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py index 0ee564ec0..4481265fd 100644 --- a/tests/firewall/nftables_test.py +++ b/tests/firewall/nftables_test.py @@ -23,7 +23,7 @@ def filter_table(nf_test_tables): @pytest.fixture def custom_chain(nf_test_tables, filter_table): name = 'test-chain' - nf_test_tables.cmd('add chain inet firewall skale-{name}') + nf_test_tables.cmd(f'add chain inet firewall skale-{name}') return name From 275c7d54be2fe9acd963752a3857023dd44b0b9c Mon Sep 17 00:00:00 2001 From: badrogger Date: Tue, 17 Dec 2024 19:54:59 +0000 Subject: [PATCH 14/30] Save rules backup after sync --- core/schains/firewall/firewall_manager.py | 5 +++++ core/schains/firewall/iptables.py | 3 +++ core/schains/firewall/nftables.py | 27 +++++++++++++++++++---- core/schains/firewall/types.py | 4 ++++ tests/utils.py | 3 +++ tools/configs/__init__.py | 2 ++ 6 files changed, 40 insertions(+), 4 deletions(-) diff --git a/core/schains/firewall/firewall_manager.py b/core/schains/firewall/firewall_manager.py index d16f71574..1393864bc 100644 --- a/core/schains/firewall/firewall_manager.py +++ b/core/schains/firewall/firewall_manager.py @@ -71,6 +71,11 @@ def update_rules(self, rules: Iterable[SChainRule]) -> None: rules_to_remove = actual_rules - expected_rules self.add_rules(rules_to_add) self.remove_rules(rules_to_remove) + self.save_rules() + + def save_rules(self) -> None: + """ Saves rules into persistent storage """ + self.host_controller.save_rules() def add_rules(self, rules: Iterable[SChainRule]) -> None: logger.debug('Adding rules %s', rules) diff --git a/core/schains/firewall/iptables.py b/core/schains/firewall/iptables.py index 1d28c4037..fbe1b55f4 100644 --- a/core/schains/firewall/iptables.py +++ b/core/schains/firewall/iptables.py @@ -139,3 +139,6 @@ def from_ip_network(cls, ip: str) -> str: @classmethod def to_ip_network(cls, ip: str) -> str: return str(ipaddress.ip_network(ip)) + + def save_rules(self): + raise NotImplementedError('save_rules is not implemented for iptables host controller') diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 669bd0273..57551ce5b 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -18,16 +18,17 @@ # along with this program. If not, see . -import logging import importlib import ipaddress +import json +import logging import multiprocessing -from typing import Iterable +import os +from typing import Iterable, TypeVar from core.schains.firewall.types import IHostFirewallController, SChainRule -from typing import TypeVar -import json +from tools.configs import NFT_CHAIN_BASE_PATH T = TypeVar('T') @@ -315,3 +316,21 @@ def from_ip_network(cls, ip: str) -> str: @classmethod def to_ip_network(cls, ip: str) -> str: return str(ipaddress.ip_network(ip)) + + def get_plain_chain_rules(self) -> str: + self.nft.set_json_output(False) + output = '' + try: + rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}') + if rc != 0: + raise NFTablesCmdFailedError(f"Failed to get table content: {error}") + finally: + self.nft.set_json_output(True) + + return output + + def save_rules(self) -> None: + chain_rules = self.get_plain_chain_rules() + nft_chain_path = os.path.join(NFT_CHAIN_BASE_PATH, f'{self.chain}.conf') + with open(nft_chain_path, 'w') as nft_chain_file: + nft_chain_file.write(chain_rules) diff --git a/core/schains/firewall/types.py b/core/schains/firewall/types.py index 65ba8885d..0d25e6d7f 100644 --- a/core/schains/firewall/types.py +++ b/core/schains/firewall/types.py @@ -88,6 +88,10 @@ def rules(self) -> Iterable[SChainRule]: # pragma: no cover def has_rule(self, rule: SChainRule) -> bool: # pragma: no cover pass + @abstractmethod + def save_rules(self) -> None: # pragma: no cover + pass + class IFirewallManager(ABC): @property diff --git a/tests/utils.py b/tests/utils.py index dc33bf91b..e7fa56881 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -222,6 +222,9 @@ def rules(self): def has_rule(self, srule): return srule in self._rules + def save_rules(self): + pass + class SChainTestFirewallManager(SChainFirewallManager): def create_host_controller(self): diff --git a/tools/configs/__init__.py b/tools/configs/__init__.py index 4794de043..e1b0e053f 100644 --- a/tools/configs/__init__.py +++ b/tools/configs/__init__.py @@ -106,3 +106,5 @@ SYNC_NODE = os.getenv('SYNC_NODE') == 'True' DOCKER_NODE_CONFIG_FILEPATH = os.path.join(NODE_DATA_PATH, 'docker.json') + +NFT_CHAIN_BASE_PATH = '/etc/nft.conf.d/chains' From 517a3a858bb8c20fd4242252018d431df89b9214 Mon Sep 17 00:00:00 2001 From: badrogger Date: Wed, 18 Dec 2024 12:03:19 +0000 Subject: [PATCH 15/30] Fix tests --- .../firewall/default_rule_controller_test.py | 21 +++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/tests/firewall/default_rule_controller_test.py b/tests/firewall/default_rule_controller_test.py index 0eb4668a3..c21489a8f 100644 --- a/tests/firewall/default_rule_controller_test.py +++ b/tests/firewall/default_rule_controller_test.py @@ -1,7 +1,10 @@ +import concurrent.futures import mock +import os +import shutil + import pytest -import concurrent.futures from skale.schain_config import PORTS_PER_SCHAIN # noqa @@ -21,7 +24,17 @@ def refresh(): run_cmd(['nft', 'flush', 'ruleset']) -def test_get_default_rule_controller(): +@pytest.fixture() +def nft_chain_folder(): + path = '/etc/nft.conf.d/chains' + try: + os.makedirs(path) + yield path + finally: + shutil.rmtree(path) + + +def test_get_default_rule_controller(nft_chain_folder): own_ip = '3.3.3.3' node_ips = ['1.1.1.1', '2.2.2.2', '3.3.3.3', '4.4.4.4'] base_port = 10064 @@ -175,7 +188,7 @@ def run_concurrent_rc_syncing( @pytest.mark.parametrize('attempt', range(5)) -def test_concurrent_rc_behavior_no_refresh(attempt): +def test_concurrent_rc_behavior_no_refresh(attempt, nft_chain_folder): node_number = 16 schain_number = 8 own_ip = '1.1.1.1' @@ -191,7 +204,7 @@ def test_concurrent_rc_behavior_no_refresh(attempt): @pytest.mark.parametrize('attempt', range(5)) -def test_concurrent_rc_behavior_with_refresh(attempt, refresh): +def test_concurrent_rc_behavior_with_refresh(attempt, refresh, nft_chain_folder): node_number = 16 schain_number = 8 own_ip = '1.1.1.1' From 0aacca37e28b5ece4e282c04b517adede9f727ce Mon Sep 17 00:00:00 2001 From: badrogger Date: Thu, 26 Dec 2024 12:59:41 +0000 Subject: [PATCH 16/30] Cleanup nftables chain after schain removal --- core/schains/firewall/firewall_manager.py | 1 + core/schains/firewall/iptables.py | 5 ++++- core/schains/firewall/nftables.py | 22 ++++++++++++++++++++++ core/schains/firewall/types.py | 4 ++++ 4 files changed, 31 insertions(+), 1 deletion(-) diff --git a/core/schains/firewall/firewall_manager.py b/core/schains/firewall/firewall_manager.py index 1393864bc..5b2f76407 100644 --- a/core/schains/firewall/firewall_manager.py +++ b/core/schains/firewall/firewall_manager.py @@ -89,6 +89,7 @@ def remove_rules(self, rules: Iterable[SChainRule]) -> None: def flush(self) -> None: self.remove_rules(self.rules) + self.host_controller.cleanup() class IptablesSChainFirewallManager(SChainFirewallManager): diff --git a/core/schains/firewall/iptables.py b/core/schains/firewall/iptables.py index fbe1b55f4..589250d68 100644 --- a/core/schains/firewall/iptables.py +++ b/core/schains/firewall/iptables.py @@ -140,5 +140,8 @@ def from_ip_network(cls, ip: str) -> str: def to_ip_network(cls, ip: str) -> str: return str(ipaddress.ip_network(ip)) - def save_rules(self): + def save_rules(self) -> None: raise NotImplementedError('save_rules is not implemented for iptables host controller') + + def cleanup(self) -> None: + raise NotImplementedError('cleanup is not implemented for iptables host controller') diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 57551ce5b..d0ff2dd2a 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -119,6 +119,25 @@ def create_chain(self, first_port: int, last_port: int) -> None: ) self.add_schain_drop_rule(first_port, last_port) + def delete_chain(self) -> None: + if self.has_chain(self.chain): + logger.info('Removing chain %s', self.chain) + self.run_json_cmd( + self._compose_json( + [ + { + 'delete': { + 'chain': { + 'family': self.FAMILY, + 'table': self.table, + 'name': self.chain + } + } + } + ] + ) + ) + @property def chains(self) -> list[dict]: output = self.run_cmd('list chains') @@ -334,3 +353,6 @@ def save_rules(self) -> None: nft_chain_path = os.path.join(NFT_CHAIN_BASE_PATH, f'{self.chain}.conf') with open(nft_chain_path, 'w') as nft_chain_file: nft_chain_file.write(chain_rules) + + def cleanup(self) -> None: + self.delete_chain() diff --git a/core/schains/firewall/types.py b/core/schains/firewall/types.py index 0d25e6d7f..0062cccec 100644 --- a/core/schains/firewall/types.py +++ b/core/schains/firewall/types.py @@ -92,6 +92,10 @@ def has_rule(self, rule: SChainRule) -> bool: # pragma: no cover def save_rules(self) -> None: # pragma: no cover pass + @abstractmethod + def cleanup(self) -> None: # pragma: no cover + pass + class IFirewallManager(ABC): @property From 80f773b0eb3b1948aba59aa6aa8f4017c6ac1b19 Mon Sep 17 00:00:00 2001 From: badrogger Date: Thu, 26 Dec 2024 15:28:26 +0000 Subject: [PATCH 17/30] Fix tests --- tests/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/utils.py b/tests/utils.py index e7fa56881..014c054f8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -225,6 +225,9 @@ def has_rule(self, srule): def save_rules(self): pass + def cleanup(self): + pass + class SChainTestFirewallManager(SChainFirewallManager): def create_host_controller(self): From 71e1389012d8b34cda2f5d20069fef4f2a6b6812 Mon Sep 17 00:00:00 2001 From: badrogger Date: Mon, 30 Dec 2024 12:48:27 +0000 Subject: [PATCH 18/30] Small improvements --- core/schains/firewall/nftables.py | 143 ++++++++++++++---------------- 1 file changed, 66 insertions(+), 77 deletions(-) diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index d0ff2dd2a..00d81516e 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -19,19 +19,16 @@ import importlib -import ipaddress import json import logging import multiprocessing import os -from typing import Iterable, TypeVar +from typing import Iterable from core.schains.firewall.types import IHostFirewallController, SChainRule from tools.configs import NFT_CHAIN_BASE_PATH -T = TypeVar('T') - logger = logging.getLogger(__name__) @@ -55,6 +52,69 @@ def __init__(self, table: str = TABLE, chain: str = CHAIN) -> None: self.nft = self._nftables.Nftables() self.nft.set_json_output(True) + @classmethod + def rule_to_expr(cls, rule: SChainRule, counter: bool = True) -> list: + expr = [] + + if rule.first_ip: + if rule.last_ip == rule.first_ip: + expr.append( + { + 'match': { + 'op': '==', + 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, + 'right': f'{rule.first_ip}', + } + } + ) + else: + expr.append( + { + 'match': { + 'op': '==', + 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, + 'right': {'range': [f'{rule.first_ip}', f'{rule.last_ip}']}, + } + } + ) + + if rule.port: + expr.append( + { + 'match': { + 'op': '==', + 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}}, + 'right': rule.port, + } + } + ) + + if counter: + expr.append({'counter': None}) + + expr.append({'accept': None}) + return expr + + @classmethod + def expr_to_rule(self, expr: list) -> None: + port, first_ip, last_ip = None, None, None + for item in expr: + if 'match' in item: + match = item['match'] + + if match.get('left', {}).get('payload', {}).get('field') == 'dport': + port = match.get('right') + + if match.get('left', {}).get('payload', {}).get('field') == 'saddr': + right = match.get('right') + if isinstance(right, str): + first_ip = right + else: + first_ip, last_ip = right['range'] + + if any([port, first_ip, last_ip]): + return SChainRule(port=port, first_ip=first_ip, last_ip=last_ip) + def _compose_json(self, commands: list[dict]) -> dict: json_cmd = {'nftables': commands} self.nft.json_validate(json_cmd) @@ -190,73 +250,10 @@ def add_rule(self, rule: SChainRule) -> None: ] ) - rc, output, error = self.run_json_cmd(json_cmd) + rc, _, error = self.run_json_cmd(json_cmd) if rc != 0: raise NFTablesCmdFailedError(f'Failed to add allow rule: {error}') - @classmethod - def rule_to_expr(cls, rule: SChainRule, counter: bool = True) -> list: - expr = [] - - if rule.first_ip: - if rule.last_ip == rule.first_ip: - expr.append( - { - 'match': { - 'op': '==', - 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, - 'right': f'{rule.first_ip}', - } - } - ) - else: - expr.append( - { - 'match': { - 'op': '==', - 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, - 'right': {'range': [f'{rule.first_ip}', f'{rule.last_ip}']}, - } - } - ) - - if rule.port: - expr.append( - { - 'match': { - 'op': '==', - 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}}, - 'right': rule.port, - } - } - ) - - if counter: - expr.append({'counter': None}) - - expr.append({'accept': None}) - return expr - - @classmethod - def expr_to_rule(self, expr: list) -> None: - port, first_ip, last_ip = None, None, None - for item in expr: - if 'match' in item: - match = item['match'] - - if match.get('left', {}).get('payload', {}).get('field') == 'dport': - port = match.get('right') - - if match.get('left', {}).get('payload', {}).get('field') == 'saddr': - right = match.get('right') - if isinstance(right, str): - first_ip = right - else: - first_ip, last_ip = right['range'] - - if any([port, first_ip, last_ip]): - return SChainRule(port=port, first_ip=first_ip, last_ip=last_ip) - def remove_rule(self, rule: SChainRule) -> None: if self.has_rule(rule): expr = self.rule_to_expr(rule, counter=False) @@ -297,7 +294,7 @@ def remove_rule(self, rule: SChainRule) -> None: ] ) - rc, output, error = self.run_json_cmd(json_cmd) + rc, _, error = self.run_json_cmd(json_cmd) if rc != 0: raise NFTablesCmdFailedError(f'Failed to delete rule: {error}') @@ -328,14 +325,6 @@ def get_rules_by_policy(self, policy: str) -> list[SChainRule]: logger.debug('Rules for policy %s: %s', policy, rules) return rules - @classmethod - def from_ip_network(cls, ip: str) -> str: - return str(ipaddress.ip_network(ip).hosts()[0]) - - @classmethod - def to_ip_network(cls, ip: str) -> str: - return str(ipaddress.ip_network(ip)) - def get_plain_chain_rules(self) -> str: self.nft.set_json_output(False) output = '' From d4356b7d342569749b65524fbd09c557dc6d5429 Mon Sep 17 00:00:00 2001 From: badrogger Date: Thu, 2 Jan 2025 11:43:04 +0000 Subject: [PATCH 19/30] Bump version to 2.9.0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 834f26295..c8e38b614 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.8.0 +2.9.0 From c657e0c64ef7e73eb3de833b012322dc528426bc Mon Sep 17 00:00:00 2001 From: badrogger Date: Tue, 14 Jan 2025 16:47:17 +0000 Subject: [PATCH 20/30] Fix nftables chain cleanup --- core/schains/firewall/nftables.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 00d81516e..486846cf7 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -23,6 +23,7 @@ import logging import multiprocessing import os +import shutil from typing import Iterable from core.schains.firewall.types import IHostFirewallController, SChainRule @@ -343,5 +344,10 @@ def save_rules(self) -> None: with open(nft_chain_path, 'w') as nft_chain_file: nft_chain_file.write(chain_rules) + def remove_saved_rules(self) -> None: + nft_chain_path = os.path.join(NFT_CHAIN_BASE_PATH, f'{self.chain}.conf') + shutil.rmtree(nft_chain_path) + def cleanup(self) -> None: self.delete_chain() + self.remove_saved_rules() From 934cee45748355a327c34dc405e49cff7607c97f Mon Sep 17 00:00:00 2001 From: badrogger Date: Wed, 15 Jan 2025 13:37:51 +0000 Subject: [PATCH 21/30] Make cleaner to remove schain firewall rules config --- core/schains/firewall/nftables.py | 8 +++++-- tests/firewall/nftables_test.py | 36 ++++++++++++++++++++++++++++++- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 486846cf7..b7b312606 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -52,6 +52,7 @@ def __init__(self, table: str = TABLE, chain: str = CHAIN) -> None: self._nftables = importlib.import_module('nftables') self.nft = self._nftables.Nftables() self.nft.set_json_output(True) + self.nft.set_stateless_output(True) @classmethod def rule_to_expr(cls, rule: SChainRule, counter: bool = True) -> list: @@ -179,6 +180,7 @@ def create_chain(self, first_port: int, last_port: int) -> None: ) ) self.add_schain_drop_rule(first_port, last_port) + self.save_rules() def delete_chain(self) -> None: if self.has_chain(self.chain): @@ -330,7 +332,9 @@ def get_plain_chain_rules(self) -> str: self.nft.set_json_output(False) output = '' try: - rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}') + rc, output, error = self.run_cmd( + f'list chain {self.FAMILY} {self.table} {self.chain}' + ) if rc != 0: raise NFTablesCmdFailedError(f"Failed to get table content: {error}") finally: @@ -346,7 +350,7 @@ def save_rules(self) -> None: def remove_saved_rules(self) -> None: nft_chain_path = os.path.join(NFT_CHAIN_BASE_PATH, f'{self.chain}.conf') - shutil.rmtree(nft_chain_path) + os.remove(nft_chain_path) def cleanup(self) -> None: self.delete_chain() diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py index 4481265fd..c030aa260 100644 --- a/tests/firewall/nftables_test.py +++ b/tests/firewall/nftables_test.py @@ -1,11 +1,14 @@ import concurrent.futures import importlib +import os +import shutil import time import pytest -from core.schains.firewall.nftables import NFTablesController +from core.schains.firewall.nftables import NFTablesController, NFT_CHAIN_BASE_PATH from core.schains.firewall.types import SChainRule +from tools.helper import run_cmd @pytest.fixture @@ -15,6 +18,16 @@ def nf_test_tables(): return nft +@pytest.fixture() +def nft_chain_folder(): + path = '/etc/nft.conf.d/chains' + try: + os.makedirs(path) + yield path + finally: + shutil.rmtree(path) + + @pytest.fixture def filter_table(nf_test_tables): print(nf_test_tables.cmd('add table inet firewall')) @@ -67,6 +80,27 @@ def test_nftables_controller_duplicates(custom_chain): ] +def test_create_delete_chain(filter_table, nft_chain_folder): + chain_name = 'test-chain' + + output = run_cmd(['nft', 'list', 'chains']).stdout.decode('utf-8') + output == 'table inet firewall {\n}\n' + nft_chain_path = os.path.join(NFT_CHAIN_BASE_PATH, f'skale-{chain_name}.conf') + assert not os.path.isfile(nft_chain_path) + + manager = NFTablesController(chain=chain_name) + manager.create_chain(first_port=10000, last_port=10063) + + chains = run_cmd(['nft', 'list', 'chains']).stdout.decode('utf-8') + assert chains == 'table inet firewall {\n\tchain skale-test-chain {\n\t\ttype filter hook input priority filter; policy accept;\n\t}\n}\n' # noqa + assert os.path.isfile(nft_chain_path) + + manager.cleanup() + chains = run_cmd(['nft', 'list', 'chains']).stdout.decode('utf-8') + assert chains == 'table inet firewall {\n}\n' + assert not os.path.isfile(nft_chain_path) + + def add_remove_rule(srule, refresh): manager = NFTablesController() manager.add_rule(srule) From 4a1c830892f29023642bd866da4dc7d359e5621f Mon Sep 17 00:00:00 2001 From: badrogger Date: Wed, 15 Jan 2025 13:38:53 +0000 Subject: [PATCH 22/30] Fix linter --- core/schains/firewall/nftables.py | 1 - 1 file changed, 1 deletion(-) diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index b7b312606..f333dd4fb 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -23,7 +23,6 @@ import logging import multiprocessing import os -import shutil from typing import Iterable from core.schains.firewall.types import IHostFirewallController, SChainRule From 218ede8ec932d44217fc527444d04a2be4ea1d16 Mon Sep 17 00:00:00 2001 From: badrogger Date: Wed, 15 Jan 2025 17:53:57 +0000 Subject: [PATCH 23/30] Reorder operations --- core/schains/firewall/nftables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index f333dd4fb..d94d9162d 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -352,5 +352,5 @@ def remove_saved_rules(self) -> None: os.remove(nft_chain_path) def cleanup(self) -> None: - self.delete_chain() self.remove_saved_rules() + self.delete_chain() From f4361cb9f4294e8756f8efb208e7946cbdfc7623 Mon Sep 17 00:00:00 2001 From: badrogger Date: Sat, 18 Jan 2025 16:03:32 +0000 Subject: [PATCH 24/30] Make sure rules persistent and chain configured --- core/schains/checks.py | 18 +++++++- core/schains/firewall/firewall_manager.py | 8 ++++ core/schains/firewall/nftables.py | 44 +++++++++++++++---- core/schains/firewall/rule_controller.py | 16 +++++++ core/schains/firewall/types.py | 8 ++++ tests/conftest.py | 10 +++++ .../firewall/default_rule_controller_test.py | 12 ----- tests/firewall/nftables_test.py | 11 ----- tests/utils.py | 6 +++ tools/configs/__init__.py | 2 +- 10 files changed, 100 insertions(+), 35 deletions(-) diff --git a/core/schains/checks.py b/core/schains/checks.py index 8beb26739..c56a8d720 100644 --- a/core/schains/checks.py +++ b/core/schains/checks.py @@ -301,6 +301,12 @@ def volume(self) -> CheckRes: @property def firewall_rules(self) -> CheckRes: """Checks that firewall rules are set correctly""" + data = { + 'config': False, + 'inited': False, + 'rules': False, + 'persistant': False, + } if self.config: conf = self.cfm.skaled_config base_port = get_base_port_from_config(conf) @@ -311,8 +317,16 @@ def firewall_rules(self) -> CheckRes: base_port=base_port, own_ip=own_ip, node_ips=node_ips, sync_ip_ranges=ranges ) logger.debug(f'Rule controller {self.rc.expected_rules()}') - return CheckRes(self.rc.is_rules_synced()) - return CheckRes(False) + data = { + 'config': True, + 'inited': self.rc.is_inited(), + 'rules': self.rc.is_rules_synced(), + 'persistent': self.rc.is_persistent(), + } + logger.debug('Firewall rules check: %s', data) + status = all(data.values()) + return CheckRes(status=status, data=data) + return CheckRes(status=False, data=data) @property def skaled_container(self) -> CheckRes: diff --git a/core/schains/firewall/firewall_manager.py b/core/schains/firewall/firewall_manager.py index 5b2f76407..6c9f37689 100644 --- a/core/schains/firewall/firewall_manager.py +++ b/core/schains/firewall/firewall_manager.py @@ -103,3 +103,11 @@ def create_host_controller(self) -> NFTablesController: nc_controller.create_table() nc_controller.create_chain(self.first_port, self.last_port) return nc_controller + + def rules_saved(self) -> bool: + saved = self.host_controller.get_saved_rules() + return saved == self.host_controller.get_rules() + + def base_config_applied(self) -> bool: + return self.host_controller.has_chain(self.host_controller.chain) and \ + self.host_controller.has_drop_rule(self.first_port, self.last_port) diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index d94d9162d..04f2c7535 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -# # This file is part of SKALE Admin # # Copyright (C) 2024 SKALE Labs @@ -125,7 +124,7 @@ def create_table(self) -> None: if not self.has_table(self.table): return self.run_cmd(f'add table inet {self.table}') - def add_schain_drop_rule(self, first_port: int, last_port: int) -> None: + def has_drop_rule(self, first_port: int, last_port: int) -> bool: expr = [ { 'match': { @@ -138,7 +137,22 @@ def add_schain_drop_rule(self, first_port: int, last_port: int) -> None: {'drop': None}, ] - if self.expr_to_rule(expr) not in self.get_rules_by_policy(policy='drop'): + return self.expr_to_rule(expr) in self.get_rules_by_policy(policy='drop') + + def add_schain_drop_rule(self, first_port: int, last_port: int) -> None: + if not self.has_drop_rule(first_port, last_port): + expr = [ + { + 'match': { + 'op': '==', + 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}}, + 'right': {'range': [first_port, last_port]}, + } + }, + {'counter': None}, + {'drop': None}, + ] + cmd = { 'nftables': [ { @@ -178,8 +192,8 @@ def create_chain(self, first_port: int, last_port: int) -> None: ] ) ) - self.add_schain_drop_rule(first_port, last_port) - self.save_rules() + self.add_schain_drop_rule(first_port, last_port) + self.save_rules() def delete_chain(self) -> None: if self.has_chain(self.chain): @@ -339,17 +353,29 @@ def get_plain_chain_rules(self) -> str: finally: self.nft.set_json_output(True) + # cleanup table header + output = '\n'.join(output.split('\n')[2:-1]) + return output + @property + def nft_chain_path(self) -> str: + return os.path.join(NFT_CHAIN_BASE_PATH, f'{self.chain}.conf') + def save_rules(self) -> None: + logger.info('Saving the firewall rules for chain %s', self.chain) chain_rules = self.get_plain_chain_rules() - nft_chain_path = os.path.join(NFT_CHAIN_BASE_PATH, f'{self.chain}.conf') - with open(nft_chain_path, 'w') as nft_chain_file: + with open(self.nft_chain_path, 'w') as nft_chain_file: nft_chain_file.write(chain_rules) + def get_saved_rules(self) -> str: + if not os.path.isfile(self.nft_chain_path): + return '' + with open(self.nft_chain_path, 'r') as nft_chain_file: + return nft_chain_file.read() + def remove_saved_rules(self) -> None: - nft_chain_path = os.path.join(NFT_CHAIN_BASE_PATH, f'{self.chain}.conf') - os.remove(nft_chain_path) + os.remove(self.nft_chain_path) def cleanup(self) -> None: self.remove_saved_rules() diff --git a/core/schains/firewall/rule_controller.py b/core/schains/firewall/rule_controller.py index 3b63026bb..a109d30b5 100644 --- a/core/schains/firewall/rule_controller.py +++ b/core/schains/firewall/rule_controller.py @@ -215,6 +215,14 @@ def create_firewall_manager(self) -> IptablesSChainFirewallManager: self.base_port + self.ports_per_schain - 1 # type: ignore ) + @configured_only + def is_persistent(self) -> bool: + return True + + @configured_only + def is_inited(self) -> bool: + return True + class NFTSchainRuleController(SChainRuleController): @configured_only @@ -224,3 +232,11 @@ def create_firewall_manager(self) -> NFTSchainFirewallManager: self.base_port, # type: ignore self.base_port + self.ports_per_schain - 1 # type: ignore ) + + @configured_only + def is_persistent(self) -> bool: + return self.firewall_manager.rules_saved() + + @configured_only + def is_inited(self) -> bool: + return self.firewall_manager.base_config_applied() diff --git a/core/schains/firewall/types.py b/core/schains/firewall/types.py index 0062cccec..ecb076c66 100644 --- a/core/schains/firewall/types.py +++ b/core/schains/firewall/types.py @@ -139,3 +139,11 @@ def sync(self) -> None: # pragma: no cover @abstractmethod def cleanup(self) -> None: # pragma: no cover pass + + @abstractmethod + def is_persistent(self) -> bool: # pragma: no cover + pass + + @abstractmethod + def is_inited(self) -> bool: # pragma: no cover + pass diff --git a/tests/conftest.py b/tests/conftest.py index 973a375e7..ac4b92eab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -631,3 +631,13 @@ def ncli_status(_schain_name): yield init_node_cli_status(_schain_name) finally: shutil.rmtree(schain_dir_path, ignore_errors=True) + + +@pytest.fixture() +def nft_chain_folder(): + path = '/etc/nft.conf.d/skale/chains' + try: + os.makedirs(path) + yield path + finally: + shutil.rmtree(path) diff --git a/tests/firewall/default_rule_controller_test.py b/tests/firewall/default_rule_controller_test.py index c21489a8f..cadd5f5dd 100644 --- a/tests/firewall/default_rule_controller_test.py +++ b/tests/firewall/default_rule_controller_test.py @@ -1,8 +1,6 @@ import concurrent.futures import mock -import os -import shutil import pytest @@ -24,16 +22,6 @@ def refresh(): run_cmd(['nft', 'flush', 'ruleset']) -@pytest.fixture() -def nft_chain_folder(): - path = '/etc/nft.conf.d/chains' - try: - os.makedirs(path) - yield path - finally: - shutil.rmtree(path) - - def test_get_default_rule_controller(nft_chain_folder): own_ip = '3.3.3.3' node_ips = ['1.1.1.1', '2.2.2.2', '3.3.3.3', '4.4.4.4'] diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py index c030aa260..9f7f8c63e 100644 --- a/tests/firewall/nftables_test.py +++ b/tests/firewall/nftables_test.py @@ -1,7 +1,6 @@ import concurrent.futures import importlib import os -import shutil import time import pytest @@ -18,16 +17,6 @@ def nf_test_tables(): return nft -@pytest.fixture() -def nft_chain_folder(): - path = '/etc/nft.conf.d/chains' - try: - os.makedirs(path) - yield path - finally: - shutil.rmtree(path) - - @pytest.fixture def filter_table(nf_test_tables): print(nf_test_tables.cmd('add table inet firewall')) diff --git a/tests/utils.py b/tests/utils.py index 014c054f8..06b7e5150 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -242,6 +242,12 @@ def create_firewall_manager(self): self.base_port + self.ports_per_schain ) + def is_persistent(self) -> bool: + return True + + def is_inited(self) -> bool: + return True + def get_test_rule_controller( name, diff --git a/tools/configs/__init__.py b/tools/configs/__init__.py index e1b0e053f..ec04f6175 100644 --- a/tools/configs/__init__.py +++ b/tools/configs/__init__.py @@ -107,4 +107,4 @@ DOCKER_NODE_CONFIG_FILEPATH = os.path.join(NODE_DATA_PATH, 'docker.json') -NFT_CHAIN_BASE_PATH = '/etc/nft.conf.d/chains' +NFT_CHAIN_BASE_PATH = '/etc/nft.conf.d/skale/chains' From 8bacef041dc7eb0438c2750b41b52fe27cc3027d Mon Sep 17 00:00:00 2001 From: badrogger Date: Tue, 21 Jan 2025 12:09:12 +0000 Subject: [PATCH 25/30] Save firewall rules properly --- core/schains/firewall/firewall_manager.py | 2 +- core/schains/firewall/nftables.py | 13 +++++++++++-- tests/firewall/nftables_test.py | 16 ++++++++++++++++ 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/core/schains/firewall/firewall_manager.py b/core/schains/firewall/firewall_manager.py index 6c9f37689..6b33ab3bd 100644 --- a/core/schains/firewall/firewall_manager.py +++ b/core/schains/firewall/firewall_manager.py @@ -106,7 +106,7 @@ def create_host_controller(self) -> NFTablesController: def rules_saved(self) -> bool: saved = self.host_controller.get_saved_rules() - return saved == self.host_controller.get_rules() + return saved == self.host_controller.get_plain_chain_rules() def base_config_applied(self) -> bool: return self.host_controller.has_chain(self.host_controller.chain) and \ diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 04f2c7535..82addc751 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -353,9 +353,18 @@ def get_plain_chain_rules(self) -> str: finally: self.nft.set_json_output(True) + lines = output.split('\n') # cleanup table header - output = '\n'.join(output.split('\n')[2:-1]) - + if lines[-1] == '': + lines = lines[1:-2] + else: + lines = lines[1:-1] + + # remove leading tab + lines = list(map(lambda line: line[1:], lines)) + # Adding new line at the end to prevent validation failure + lines.append('') + output = '\n'.join(lines) return output @property diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py index 9f7f8c63e..cee70eed5 100644 --- a/tests/firewall/nftables_test.py +++ b/tests/firewall/nftables_test.py @@ -90,6 +90,22 @@ def test_create_delete_chain(filter_table, nft_chain_folder): assert not os.path.isfile(nft_chain_path) +def test_saved_rules(filter_table, nft_chain_folder): + chain_name = 'test-chain' + nft_chain_path = os.path.join(NFT_CHAIN_BASE_PATH, f'skale-{chain_name}.conf') + + manager = NFTablesController(chain=chain_name) + assert not os.path.isfile(nft_chain_path) + manager.create_chain(first_port=10000, last_port=10063) + assert os.path.isfile(nft_chain_path) + assert manager.get_saved_rules() == 'chain skale-test-chain {\n\ttype filter hook input priority filter; policy accept;\n\ttcp dport 10000-10063 counter drop\n}\n' # noqa + + assert os.path.isfile(nft_chain_path) + + manager.remove_saved_rules() + assert not os.path.isfile(nft_chain_path) + + def add_remove_rule(srule, refresh): manager = NFTablesController() manager.add_rule(srule) From 6ce86c086d3e3fa034696e752e5ec07fc61a6022 Mon Sep 17 00:00:00 2001 From: badrogger Date: Tue, 21 Jan 2025 20:12:46 +0000 Subject: [PATCH 26/30] Fix cleaner procedure --- core/schains/checks.py | 2 -- core/schains/cleaner.py | 30 ++++++++++++++++++++++++++---- core/schains/firewall/nftables.py | 3 ++- tools/configs/__init__.py | 1 + 4 files changed, 29 insertions(+), 7 deletions(-) diff --git a/core/schains/checks.py b/core/schains/checks.py index c56a8d720..0a42bbdbb 100644 --- a/core/schains/checks.py +++ b/core/schains/checks.py @@ -302,7 +302,6 @@ def volume(self) -> CheckRes: def firewall_rules(self) -> CheckRes: """Checks that firewall rules are set correctly""" data = { - 'config': False, 'inited': False, 'rules': False, 'persistant': False, @@ -318,7 +317,6 @@ def firewall_rules(self) -> CheckRes: ) logger.debug(f'Rule controller {self.rc.expected_rules()}') data = { - 'config': True, 'inited': self.rc.is_inited(), 'rules': self.rc.is_rules_synced(), 'persistent': self.rc.is_persistent(), diff --git a/core/schains/cleaner.py b/core/schains/cleaner.py index 7fd291efd..0511e4e70 100644 --- a/core/schains/cleaner.py +++ b/core/schains/cleaner.py @@ -17,6 +17,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import glob import logging import os import shutil @@ -43,7 +44,11 @@ from core.schains.types import ContainerType from core.schains.firewall.utils import get_sync_agent_ranges -from tools.configs import SGX_CERTIFICATES_FOLDER, SYNC_NODE +from tools.configs import ( + NFT_CHAIN_CONFIG_WILDCARD, + SGX_CERTIFICATES_FOLDER, + SYNC_NODE +) from tools.configs.schains import SCHAINS_DIR_PATH from tools.configs.containers import SCHAIN_CONTAINER, IMA_CONTAINER, SCHAIN_STOP_TIMEOUT from tools.docker_utils import DockerUtils @@ -136,18 +141,34 @@ def get_schains_with_containers(dutils=None): ] +def get_schains_firewall_configs() -> list: + return list(map(lambda path: os.path.basename(path), glob.glob(NFT_CHAIN_CONFIG_WILDCARD))) + + def get_schains_on_node(dutils=None): dutils = dutils or DockerUtils() schains_with_dirs = os.listdir(SCHAINS_DIR_PATH) schains_with_container = get_schains_with_containers(dutils) schains_active_records = get_schains_names() + schains_firewall_configs = list( + map(lambda name: name.removeprefix('skale-'), + get_schains_firewall_configs()) + ) logger.info( - 'dirs %s, containers: %s, records: %s', + 'dirs %s, containers: %s, records: %s, firewall configs: %s', schains_with_dirs, schains_with_container, - schains_active_records + schains_active_records, + schains_firewall_configs + ) + return sorted( + merged_unique( + schains_with_dirs, + schains_with_container, + schains_active_records, + schains_firewall_configs + ) ) - return sorted(merged_unique(schains_with_dirs, schains_with_container, schains_active_records)) def schain_names_to_ids(skale, schain_names): @@ -268,6 +289,7 @@ def cleanup_schain( ranges = estate.ranges rc.configure(base_port=base_port, own_ip=own_ip, node_ips=node_ips, sync_ip_ranges=ranges) rc.cleanup() + if estate is not None and estate.ima_linked: if check_status.get('ima_container', False) or is_exited( schain_name, container_type=ContainerType.ima, dutils=dutils diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 82addc751..a2977c4b2 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -384,7 +384,8 @@ def get_saved_rules(self) -> str: return nft_chain_file.read() def remove_saved_rules(self) -> None: - os.remove(self.nft_chain_path) + if os.isfile(self.nft_chain_path): + os.remove(self.nft_chain_path) def cleanup(self) -> None: self.remove_saved_rules() diff --git a/tools/configs/__init__.py b/tools/configs/__init__.py index ec04f6175..1e2fd423b 100644 --- a/tools/configs/__init__.py +++ b/tools/configs/__init__.py @@ -108,3 +108,4 @@ DOCKER_NODE_CONFIG_FILEPATH = os.path.join(NODE_DATA_PATH, 'docker.json') NFT_CHAIN_BASE_PATH = '/etc/nft.conf.d/skale/chains' +NFT_CHAIN_CONFIG_WILDCARD = os.path.join(NFT_CHAIN_BASE_PATH, '*') From 241ab43f0b392d146ff11c1209fa01e5423679bd Mon Sep 17 00:00:00 2001 From: badrogger Date: Tue, 21 Jan 2025 21:30:40 +0000 Subject: [PATCH 27/30] Fix glob --- core/schains/cleaner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/schains/cleaner.py b/core/schains/cleaner.py index 0511e4e70..9d1bd1f4f 100644 --- a/core/schains/cleaner.py +++ b/core/schains/cleaner.py @@ -22,6 +22,7 @@ import os import shutil from multiprocessing import Process +from pathlib import Path from typing import Optional from sgx import SgxClient @@ -142,7 +143,7 @@ def get_schains_with_containers(dutils=None): def get_schains_firewall_configs() -> list: - return list(map(lambda path: os.path.basename(path), glob.glob(NFT_CHAIN_CONFIG_WILDCARD))) + return list(map(lambda path: Path(path).stem, glob.glob(NFT_CHAIN_CONFIG_WILDCARD))) def get_schains_on_node(dutils=None): @@ -279,7 +280,7 @@ def cleanup_schain( remove_schain_container(schain_name, dutils=dutils) if check_status['volume']: remove_schain_volume(schain_name, dutils=dutils) - if check_status['firewall_rules']: + if any(checks.firewall_rules.data): conf = ConfigFileManager(schain_name).skaled_config base_port = get_base_port_from_config(conf) own_ip = get_own_ip_from_config(conf) From 9e33f8278e0df0ae598563771d0a327e538a70f4 Mon Sep 17 00:00:00 2001 From: badrogger Date: Wed, 22 Jan 2025 13:58:55 +0000 Subject: [PATCH 28/30] Handle firewall related cleaner failure --- core/schains/cleaner.py | 25 +++++++---------------- core/schains/firewall/firewall_manager.py | 11 ++++++---- core/schains/firewall/nftables.py | 5 ++--- core/schains/firewall/rule_controller.py | 10 ++++++--- core/schains/firewall/types.py | 2 +- core/schains/firewall/utils.py | 7 +++++++ tests/firewall/firewall_manager_test.py | 4 ++-- tests/firewall/nftables_test.py | 19 ++++++++++++++++- tests/schains/cleaner_test.py | 5 ++++- tests/utils.py | 6 ++++++ 10 files changed, 61 insertions(+), 33 deletions(-) diff --git a/core/schains/cleaner.py b/core/schains/cleaner.py index 9d1bd1f4f..881a2e633 100644 --- a/core/schains/cleaner.py +++ b/core/schains/cleaner.py @@ -30,15 +30,9 @@ from core.node import get_current_nodes, get_skale_node_version from core.schains.checks import SChainChecks -from core.schains.config.file_manager import ConfigFileManager from core.schains.config.directory import schain_config_dir from core.schains.dkg.utils import get_secret_key_share_filepath -from core.schains.firewall.utils import get_default_rule_controller -from core.schains.config.helper import ( - get_base_port_from_config, - get_node_ips_from_config, - get_own_ip_from_config, -) +from core.schains.firewall.utils import cleanup_firewall_for_schain, get_default_rule_controller from core.schains.process import ProcessReport, terminate_process from core.schains.runner import get_container_name, is_exited from core.schains.external_config import ExternalConfig @@ -152,8 +146,10 @@ def get_schains_on_node(dutils=None): schains_with_container = get_schains_with_containers(dutils) schains_active_records = get_schains_names() schains_firewall_configs = list( - map(lambda name: name.removeprefix('skale-'), - get_schains_firewall_configs()) + map( + lambda name: name.removeprefix('skale-'), + get_schains_firewall_configs() + ) ) logger.info( 'dirs %s, containers: %s, records: %s, firewall configs: %s', @@ -281,15 +277,8 @@ def cleanup_schain( if check_status['volume']: remove_schain_volume(schain_name, dutils=dutils) if any(checks.firewall_rules.data): - conf = ConfigFileManager(schain_name).skaled_config - base_port = get_base_port_from_config(conf) - own_ip = get_own_ip_from_config(conf) - node_ips = get_node_ips_from_config(conf) - ranges = [] - if estate is not None: - ranges = estate.ranges - rc.configure(base_port=base_port, own_ip=own_ip, node_ips=node_ips, sync_ip_ranges=ranges) - rc.cleanup() + logger.info('Cleaning firewall for %s', schain_name) + cleanup_firewall_for_schain(schain_name) if estate is not None and estate.ima_linked: if check_status.get('ima_container', False) or is_exited( diff --git a/core/schains/firewall/firewall_manager.py b/core/schains/firewall/firewall_manager.py index 6b33ab3bd..f9d1bad2b 100644 --- a/core/schains/firewall/firewall_manager.py +++ b/core/schains/firewall/firewall_manager.py @@ -87,15 +87,14 @@ def remove_rules(self, rules: Iterable[SChainRule]) -> None: for rule in rules: self.host_controller.remove_rule(rule) - def flush(self) -> None: - self.remove_rules(self.rules) - self.host_controller.cleanup() - class IptablesSChainFirewallManager(SChainFirewallManager): def create_host_controller(self) -> IptablesController: return IptablesController() + def cleanup(self) -> None: + self.remove_rules(self.rules) + class NFTSchainFirewallManager(SChainFirewallManager): def create_host_controller(self) -> NFTablesController: @@ -111,3 +110,7 @@ def rules_saved(self) -> bool: def base_config_applied(self) -> bool: return self.host_controller.has_chain(self.host_controller.chain) and \ self.host_controller.has_drop_rule(self.first_port, self.last_port) + + def cleanup(self) -> None: + self.host_controller.cleanup() + self.host_controller.remove_saved_rules() diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index a2977c4b2..e2a56e39f 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -44,7 +44,7 @@ class NFTablesController(IHostFirewallController): plock = multiprocessing.Lock() FAMILY = 'inet' - def __init__(self, table: str = TABLE, chain: str = CHAIN) -> None: + def __init__(self, chain: str, table: str = TABLE) -> None: self.table = table self.chain = f'skale-{chain}' self._nftables = importlib.import_module('nftables') @@ -384,9 +384,8 @@ def get_saved_rules(self) -> str: return nft_chain_file.read() def remove_saved_rules(self) -> None: - if os.isfile(self.nft_chain_path): + if os.path.isfile(self.nft_chain_path): os.remove(self.nft_chain_path) def cleanup(self) -> None: - self.remove_saved_rules() self.delete_chain() diff --git a/core/schains/firewall/rule_controller.py b/core/schains/firewall/rule_controller.py index a109d30b5..686205260 100644 --- a/core/schains/firewall/rule_controller.py +++ b/core/schains/firewall/rule_controller.py @@ -202,9 +202,6 @@ def sync(self) -> None: logger.debug('Syncing firewall rules with %s', erules) self.firewall_manager.update_rules(erules) - def cleanup(self) -> None: - self.firewall_manager.flush() - class IptablesSChainRuleController(SChainRuleController): @configured_only @@ -223,6 +220,10 @@ def is_persistent(self) -> bool: def is_inited(self) -> bool: return True + @configured_only + def cleanup(self) -> None: + self.firewall_manager.cleanup() + class NFTSchainRuleController(SChainRuleController): @configured_only @@ -240,3 +241,6 @@ def is_persistent(self) -> bool: @configured_only def is_inited(self) -> bool: return self.firewall_manager.base_config_applied() + + def cleanup(self) -> None: + self.firewall_manager.cleanup() diff --git a/core/schains/firewall/types.py b/core/schains/firewall/types.py index ecb076c66..c30bfc11b 100644 --- a/core/schains/firewall/types.py +++ b/core/schains/firewall/types.py @@ -108,7 +108,7 @@ def update_rules(self, rules: Iterable[SChainRule]) -> None: # pragma: no cover pass @abstractmethod - def flush(self) -> None: # pragma: no cover # noqa + def cleanup(self) -> None: # pragma: no cover # noqa pass diff --git a/core/schains/firewall/utils.py b/core/schains/firewall/utils.py index 1f94694fd..0788c6df1 100644 --- a/core/schains/firewall/utils.py +++ b/core/schains/firewall/utils.py @@ -25,6 +25,7 @@ from skale import Skale from .types import IpRange +from .nftables import NFTablesController from .rule_controller import IptablesSChainRuleController, NFTSchainRuleController @@ -101,3 +102,9 @@ def save_sync_ranges(sync_agent_ranges: List[IpRange], path: str) -> None: def ranges_from_plain_tuples(plain_ranges: List[Tuple]) -> List[IpRange]: return list(sorted(map(lambda r: IpRange(*r), plain_ranges))) + + +def cleanup_firewall_for_schain(schain_name: str) -> None: + nft = NFTablesController(chain=schain_name) + nft.cleanup() + nft.remove_saved_rules() diff --git a/tests/firewall/firewall_manager_test.py b/tests/firewall/firewall_manager_test.py index 719ad1bf0..04203accc 100644 --- a/tests/firewall/firewall_manager_test.py +++ b/tests/firewall/firewall_manager_test.py @@ -53,7 +53,7 @@ def test_firewall_manager_update_existed(): assert fm.host_controller.remove_rule.call_count == 0 -def test_firewall_manager_flush(): +def test_firewall_manager_cleanup(): fm = SChainTestFirewallManager('test', 10000, 10064) rules = [ SChainRule(10000, '2.2.2.2'), @@ -63,6 +63,6 @@ def test_firewall_manager_flush(): fm.add_rules(rules) fm.host_controller.add_rule(SChainRule(10072, '2.2.2.2')) - fm.flush() + fm.cleanup() assert list(fm.rules) == [] assert fm.host_controller.has_rule(SChainRule(10072, '2.2.2.2')) diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py index cee70eed5..e77af1fa6 100644 --- a/tests/firewall/nftables_test.py +++ b/tests/firewall/nftables_test.py @@ -7,6 +7,7 @@ from core.schains.firewall.nftables import NFTablesController, NFT_CHAIN_BASE_PATH from core.schains.firewall.types import SChainRule +from core.schains.firewall.utils import cleanup_firewall_for_schain from tools.helper import run_cmd @@ -87,6 +88,9 @@ def test_create_delete_chain(filter_table, nft_chain_folder): manager.cleanup() chains = run_cmd(['nft', 'list', 'chains']).stdout.decode('utf-8') assert chains == 'table inet firewall {\n}\n' + assert os.path.isfile(nft_chain_path) + + manager.remove_saved_rules() assert not os.path.isfile(nft_chain_path) @@ -106,8 +110,21 @@ def test_saved_rules(filter_table, nft_chain_folder): assert not os.path.isfile(nft_chain_path) +def test_cleanup_firewall_for_schain(filter_table, nft_chain_folder): + chain_name = 'test-chain' + nft_chain_path = os.path.join(NFT_CHAIN_BASE_PATH, f'skale-{chain_name}.conf') + + manager = NFTablesController(chain=chain_name) + manager.create_chain(first_port=10000, last_port=10063) + + cleanup_firewall_for_schain(schain_name=chain_name) + chains = run_cmd(['nft', 'list', 'chains']).stdout.decode('utf-8') + assert chains == 'table inet firewall {\n}\n' + assert not os.path.isfile(nft_chain_path) + + def add_remove_rule(srule, refresh): - manager = NFTablesController() + manager = NFTablesController(chain='test') manager.add_rule(srule) time.sleep(1) if not manager.has_rule(srule): diff --git a/tests/schains/cleaner_test.py b/tests/schains/cleaner_test.py index d16b41fd3..ea45b25bd 100644 --- a/tests/schains/cleaner_test.py +++ b/tests/schains/cleaner_test.py @@ -239,7 +239,8 @@ def test_get_schains_on_node(schain_dirs_for_monitor, ]).issubset(set(result)) -def test_remove_schain(skale, schain_db, node_config, dutils): +@mock.patch('core.schains.cleaner.cleanup_firewall_for_schain') +def test_remove_schain(cleanup_firewall_for_schain, skale, schain_db, node_config, dutils): schain_name = schain_db remove_schain(skale, node_config.id, schain_name, msg='Test remove_schain', dutils=dutils) container_name = SCHAIN_CONTAINER_NAME_TEMPLATE.format(schain_name) @@ -250,7 +251,9 @@ def test_remove_schain(skale, schain_db, node_config, dutils): assert record.is_deleted is True +@mock.patch('core.schains.cleaner.cleanup_firewall_for_schain') def test_cleanup_schain( + cleanup_firewall_rules, schain_db, node_config, schain_on_contracts, diff --git a/tests/utils.py b/tests/utils.py index 06b7e5150..ac1f7f7d7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -233,6 +233,9 @@ class SChainTestFirewallManager(SChainFirewallManager): def create_host_controller(self): return HostTestFirewallController() + def cleanup(self): + self.remove_rules(self.rules) + class SChainTestRuleController(SChainRuleController): def create_firewall_manager(self): @@ -248,6 +251,9 @@ def is_persistent(self) -> bool: def is_inited(self) -> bool: return True + def cleanup(self) -> None: + self.firewall_manager.cleanup() + def get_test_rule_controller( name, From ad35813e2125263fbb081192bf388cb48976378a Mon Sep 17 00:00:00 2001 From: badrogger Date: Wed, 22 Jan 2025 17:10:22 +0000 Subject: [PATCH 29/30] Improve rule controller test --- core/schains/checks.py | 4 ++-- core/schains/firewall/firewall_manager.py | 2 ++ tests/firewall/default_rule_controller_test.py | 7 +++++++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/core/schains/checks.py b/core/schains/checks.py index 0a42bbdbb..e3b95799b 100644 --- a/core/schains/checks.py +++ b/core/schains/checks.py @@ -316,11 +316,11 @@ def firewall_rules(self) -> CheckRes: base_port=base_port, own_ip=own_ip, node_ips=node_ips, sync_ip_ranges=ranges ) logger.debug(f'Rule controller {self.rc.expected_rules()}') - data = { + data.update({ 'inited': self.rc.is_inited(), 'rules': self.rc.is_rules_synced(), 'persistent': self.rc.is_persistent(), - } + }) logger.debug('Firewall rules check: %s', data) status = all(data.values()) return CheckRes(status=status, data=data) diff --git a/core/schains/firewall/firewall_manager.py b/core/schains/firewall/firewall_manager.py index f9d1bad2b..5ae6cbe29 100644 --- a/core/schains/firewall/firewall_manager.py +++ b/core/schains/firewall/firewall_manager.py @@ -105,6 +105,8 @@ def create_host_controller(self) -> NFTablesController: def rules_saved(self) -> bool: saved = self.host_controller.get_saved_rules() + if saved == '': + return False return saved == self.host_controller.get_plain_chain_rules() def base_config_applied(self) -> bool: diff --git a/tests/firewall/default_rule_controller_test.py b/tests/firewall/default_rule_controller_test.py index cadd5f5dd..ea211eab8 100644 --- a/tests/firewall/default_rule_controller_test.py +++ b/tests/firewall/default_rule_controller_test.py @@ -37,6 +37,9 @@ def test_get_default_rule_controller(nft_chain_folder): node_ips, sync_ip_range ) + + assert rc.is_inited() + assert rc.is_persistent() assert rc.actual_rules() == [] rc.sync() assert rc.expected_rules() == rc.actual_rules() @@ -51,6 +54,10 @@ def test_get_default_rule_controller(nft_chain_folder): assert hm.add_rule.call_count == 0 assert hm.remove_rule.call_count == 0 + rc.cleanup() + assert not rc.is_inited() + assert not rc.is_persistent() + def sync_rules(*args): rc = get_default_rule_controller(*args) From 4e6fd905cc60d5e031e4ee27f7779b541ad17c2f Mon Sep 17 00:00:00 2001 From: badrogger Date: Wed, 22 Jan 2025 19:10:55 +0000 Subject: [PATCH 30/30] Fix tests --- core/schains/checks.py | 2 +- tests/schains/checks_test.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/core/schains/checks.py b/core/schains/checks.py index e3b95799b..c0b5079f0 100644 --- a/core/schains/checks.py +++ b/core/schains/checks.py @@ -304,7 +304,7 @@ def firewall_rules(self) -> CheckRes: data = { 'inited': False, 'rules': False, - 'persistant': False, + 'persistent': False, } if self.config: conf = self.cfm.skaled_config diff --git a/tests/schains/checks_test.py b/tests/schains/checks_test.py index f0d67f32e..4ac148366 100644 --- a/tests/schains/checks_test.py +++ b/tests/schains/checks_test.py @@ -180,6 +180,8 @@ def test_volume_check(schain_checks, sample_false_checks, dutils): def test_firewall_rules_check(schain_checks, rules_unsynced_checks): schain_checks.rc.sync() + res = schain_checks.firewall_rules + print(res.data) assert schain_checks.firewall_rules assert not rules_unsynced_checks.firewall_rules.status