diff --git a/Dockerfile b/Dockerfile index 0a2c24bb..540b256f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ FROM python:3.11-bookworm -RUN apt-get update && apt-get install -y wget git libxslt-dev iptables kmod swig +RUN apt-get update && apt-get install -y wget git libxslt-dev iptables kmod swig nftables python3-nftables RUN mkdir /usr/src/admin WORKDIR /usr/src/admin @@ -8,12 +8,13 @@ WORKDIR /usr/src/admin COPY requirements.txt ./ COPY requirements-dev.txt ./ -RUN pip3 install --no-cache-dir -r requirements.txt +RUN pip3 install -r requirements.txt COPY . . RUN update-alternatives --set iptables /usr/sbin/iptables-legacy && \ update-alternatives --set ip6tables /usr/sbin/ip6tables-legacy -ENV PYTHONPATH="/usr/src/admin" +ENV PYTHONPATH="/usr/src/admin":/usr/lib/python3/dist-packages/ + ENV COLUMNS=80 diff --git a/VERSION b/VERSION index 834f2629..c8e38b61 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.8.0 +2.9.0 diff --git a/core/schains/checks.py b/core/schains/checks.py index 8beb2673..c0b5079f 100644 --- a/core/schains/checks.py +++ b/core/schains/checks.py @@ -301,6 +301,11 @@ def volume(self) -> CheckRes: @property def firewall_rules(self) -> CheckRes: """Checks that firewall rules are set correctly""" + data = { + 'inited': False, + 'rules': False, + 'persistent': False, + } if self.config: conf = self.cfm.skaled_config base_port = get_base_port_from_config(conf) @@ -311,8 +316,15 @@ def firewall_rules(self) -> CheckRes: base_port=base_port, own_ip=own_ip, node_ips=node_ips, sync_ip_ranges=ranges ) logger.debug(f'Rule controller {self.rc.expected_rules()}') - return CheckRes(self.rc.is_rules_synced()) - return CheckRes(False) + data.update({ + 'inited': self.rc.is_inited(), + 'rules': self.rc.is_rules_synced(), + 'persistent': self.rc.is_persistent(), + }) + logger.debug('Firewall rules check: %s', data) + status = all(data.values()) + return CheckRes(status=status, data=data) + return CheckRes(status=False, data=data) @property def skaled_container(self) -> CheckRes: diff --git a/core/schains/cleaner.py b/core/schains/cleaner.py index 7fd291ef..881a2e63 100644 --- a/core/schains/cleaner.py +++ b/core/schains/cleaner.py @@ -17,10 +17,12 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import glob import logging import os import shutil from multiprocessing import Process +from pathlib import Path from typing import Optional from sgx import SgxClient @@ -28,22 +30,20 @@ from core.node import get_current_nodes, get_skale_node_version from core.schains.checks import SChainChecks -from core.schains.config.file_manager import ConfigFileManager from core.schains.config.directory import schain_config_dir from core.schains.dkg.utils import get_secret_key_share_filepath -from core.schains.firewall.utils import get_default_rule_controller -from core.schains.config.helper import ( - get_base_port_from_config, - get_node_ips_from_config, - get_own_ip_from_config, -) +from core.schains.firewall.utils import cleanup_firewall_for_schain, get_default_rule_controller from core.schains.process import ProcessReport, terminate_process from core.schains.runner import get_container_name, is_exited from core.schains.external_config import ExternalConfig from core.schains.types import ContainerType from core.schains.firewall.utils import get_sync_agent_ranges -from tools.configs import SGX_CERTIFICATES_FOLDER, SYNC_NODE +from tools.configs import ( + NFT_CHAIN_CONFIG_WILDCARD, + SGX_CERTIFICATES_FOLDER, + SYNC_NODE +) from tools.configs.schains import SCHAINS_DIR_PATH from tools.configs.containers import SCHAIN_CONTAINER, IMA_CONTAINER, SCHAIN_STOP_TIMEOUT from tools.docker_utils import DockerUtils @@ -136,18 +136,36 @@ def get_schains_with_containers(dutils=None): ] +def get_schains_firewall_configs() -> list: + return list(map(lambda path: Path(path).stem, glob.glob(NFT_CHAIN_CONFIG_WILDCARD))) + + def get_schains_on_node(dutils=None): dutils = dutils or DockerUtils() schains_with_dirs = os.listdir(SCHAINS_DIR_PATH) schains_with_container = get_schains_with_containers(dutils) schains_active_records = get_schains_names() + schains_firewall_configs = list( + map( + lambda name: name.removeprefix('skale-'), + get_schains_firewall_configs() + ) + ) logger.info( - 'dirs %s, containers: %s, records: %s', + 'dirs %s, containers: %s, records: %s, firewall configs: %s', schains_with_dirs, schains_with_container, - schains_active_records + schains_active_records, + schains_firewall_configs + ) + return sorted( + merged_unique( + schains_with_dirs, + schains_with_container, + schains_active_records, + schains_firewall_configs + ) ) - return sorted(merged_unique(schains_with_dirs, schains_with_container, schains_active_records)) def schain_names_to_ids(skale, schain_names): @@ -258,16 +276,10 @@ def cleanup_schain( remove_schain_container(schain_name, dutils=dutils) if check_status['volume']: remove_schain_volume(schain_name, dutils=dutils) - if check_status['firewall_rules']: - conf = ConfigFileManager(schain_name).skaled_config - base_port = get_base_port_from_config(conf) - own_ip = get_own_ip_from_config(conf) - node_ips = get_node_ips_from_config(conf) - ranges = [] - if estate is not None: - ranges = estate.ranges - rc.configure(base_port=base_port, own_ip=own_ip, node_ips=node_ips, sync_ip_ranges=ranges) - rc.cleanup() + if any(checks.firewall_rules.data): + logger.info('Cleaning firewall for %s', schain_name) + cleanup_firewall_for_schain(schain_name) + if estate is not None and estate.ima_linked: if check_status.get('ima_container', False) or is_exited( schain_name, container_type=ContainerType.ima, dutils=dutils diff --git a/core/schains/firewall/__init__.py b/core/schains/firewall/__init__.py index 8edbd1a7..1bba60b7 100644 --- a/core/schains/firewall/__init__.py +++ b/core/schains/firewall/__init__.py @@ -19,6 +19,7 @@ from .firewall_manager import SChainFirewallManager # noqa from .iptables import IptablesController # noqa +from .nftables import NFTablesController # noqa from .rule_controller import SChainRuleController # noqa from .types import IRuleController # noqa from .utils import get_default_rule_controller # noqa diff --git a/core/schains/firewall/firewall_manager.py b/core/schains/firewall/firewall_manager.py index e2216fc7..5ae6cbe2 100644 --- a/core/schains/firewall/firewall_manager.py +++ b/core/schains/firewall/firewall_manager.py @@ -22,6 +22,7 @@ from typing import Iterable, Optional from core.schains.firewall.iptables import IptablesController +from core.schains.firewall.nftables import NFTablesController from core.schains.firewall.types import ( IFirewallManager, IHostFirewallController, @@ -70,6 +71,11 @@ def update_rules(self, rules: Iterable[SChainRule]) -> None: rules_to_remove = actual_rules - expected_rules self.add_rules(rules_to_add) self.remove_rules(rules_to_remove) + self.save_rules() + + def save_rules(self) -> None: + """ Saves rules into persistent storage """ + self.host_controller.save_rules() def add_rules(self, rules: Iterable[SChainRule]) -> None: logger.debug('Adding rules %s', rules) @@ -81,10 +87,32 @@ def remove_rules(self, rules: Iterable[SChainRule]) -> None: for rule in rules: self.host_controller.remove_rule(rule) - def flush(self) -> None: - self.remove_rules(self.rules) - class IptablesSChainFirewallManager(SChainFirewallManager): def create_host_controller(self) -> IptablesController: return IptablesController() + + def cleanup(self) -> None: + self.remove_rules(self.rules) + + +class NFTSchainFirewallManager(SChainFirewallManager): + def create_host_controller(self) -> NFTablesController: + nc_controller = NFTablesController(chain=self.name) + nc_controller.create_table() + nc_controller.create_chain(self.first_port, self.last_port) + return nc_controller + + def rules_saved(self) -> bool: + saved = self.host_controller.get_saved_rules() + if saved == '': + return False + return saved == self.host_controller.get_plain_chain_rules() + + def base_config_applied(self) -> bool: + return self.host_controller.has_chain(self.host_controller.chain) and \ + self.host_controller.has_drop_rule(self.first_port, self.last_port) + + def cleanup(self) -> None: + self.host_controller.cleanup() + self.host_controller.remove_saved_rules() diff --git a/core/schains/firewall/iptables.py b/core/schains/firewall/iptables.py index 1d28c403..589250d6 100644 --- a/core/schains/firewall/iptables.py +++ b/core/schains/firewall/iptables.py @@ -139,3 +139,9 @@ def from_ip_network(cls, ip: str) -> str: @classmethod def to_ip_network(cls, ip: str) -> str: return str(ipaddress.ip_network(ip)) + + def save_rules(self) -> None: + raise NotImplementedError('save_rules is not implemented for iptables host controller') + + def cleanup(self) -> None: + raise NotImplementedError('cleanup is not implemented for iptables host controller') diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py new file mode 100644 index 00000000..e2a56e39 --- /dev/null +++ b/core/schains/firewall/nftables.py @@ -0,0 +1,391 @@ +# -*- coding: utf-8 -*- +# This file is part of SKALE Admin +# +# Copyright (C) 2024 SKALE Labs +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + + +import importlib +import json +import logging +import multiprocessing +import os +from typing import Iterable + +from core.schains.firewall.types import IHostFirewallController, SChainRule + +from tools.configs import NFT_CHAIN_BASE_PATH + + +logger = logging.getLogger(__name__) + + +TABLE = 'firewall' +CHAIN = 'skale' + + +class NFTablesCmdFailedError(Exception): + pass + + +class NFTablesController(IHostFirewallController): + plock = multiprocessing.Lock() + FAMILY = 'inet' + + def __init__(self, chain: str, table: str = TABLE) -> None: + self.table = table + self.chain = f'skale-{chain}' + self._nftables = importlib.import_module('nftables') + self.nft = self._nftables.Nftables() + self.nft.set_json_output(True) + self.nft.set_stateless_output(True) + + @classmethod + def rule_to_expr(cls, rule: SChainRule, counter: bool = True) -> list: + expr = [] + + if rule.first_ip: + if rule.last_ip == rule.first_ip: + expr.append( + { + 'match': { + 'op': '==', + 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, + 'right': f'{rule.first_ip}', + } + } + ) + else: + expr.append( + { + 'match': { + 'op': '==', + 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, + 'right': {'range': [f'{rule.first_ip}', f'{rule.last_ip}']}, + } + } + ) + + if rule.port: + expr.append( + { + 'match': { + 'op': '==', + 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}}, + 'right': rule.port, + } + } + ) + + if counter: + expr.append({'counter': None}) + + expr.append({'accept': None}) + return expr + + @classmethod + def expr_to_rule(self, expr: list) -> None: + port, first_ip, last_ip = None, None, None + for item in expr: + if 'match' in item: + match = item['match'] + + if match.get('left', {}).get('payload', {}).get('field') == 'dport': + port = match.get('right') + + if match.get('left', {}).get('payload', {}).get('field') == 'saddr': + right = match.get('right') + if isinstance(right, str): + first_ip = right + else: + first_ip, last_ip = right['range'] + + if any([port, first_ip, last_ip]): + return SChainRule(port=port, first_ip=first_ip, last_ip=last_ip) + + def _compose_json(self, commands: list[dict]) -> dict: + json_cmd = {'nftables': commands} + self.nft.json_validate(json_cmd) + return json_cmd + + def create_table(self) -> None: + if not self.has_table(self.table): + return self.run_cmd(f'add table inet {self.table}') + + def has_drop_rule(self, first_port: int, last_port: int) -> bool: + expr = [ + { + 'match': { + 'op': '==', + 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}}, + 'right': {'range': [first_port, last_port]}, + } + }, + {'counter': None}, + {'drop': None}, + ] + + return self.expr_to_rule(expr) in self.get_rules_by_policy(policy='drop') + + def add_schain_drop_rule(self, first_port: int, last_port: int) -> None: + if not self.has_drop_rule(first_port, last_port): + expr = [ + { + 'match': { + 'op': '==', + 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}}, + 'right': {'range': [first_port, last_port]}, + } + }, + {'counter': None}, + {'drop': None}, + ] + + cmd = { + 'nftables': [ + { + 'add': { + 'rule': { + 'family': self.FAMILY, + 'table': self.table, + 'chain': self.chain, + 'expr': expr, + } + } + } + ] + } + self.run_json_cmd(cmd) + logger.info('Added drop rule for chain %s', self.chain) + + def create_chain(self, first_port: int, last_port: int) -> None: + if not self.has_chain(self.chain): + logger.info('Creating chain %s', self.chain) + self.run_json_cmd( + self._compose_json( + [ + { + 'add': { + 'chain': { + 'family': self.FAMILY, + 'table': self.table, + 'name': self.chain, + 'hook': 'input', + 'type': 'filter', + 'prio': 0, + 'policy': 'accept', + } + } + } + ] + ) + ) + self.add_schain_drop_rule(first_port, last_port) + self.save_rules() + + def delete_chain(self) -> None: + if self.has_chain(self.chain): + logger.info('Removing chain %s', self.chain) + self.run_json_cmd( + self._compose_json( + [ + { + 'delete': { + 'chain': { + 'family': self.FAMILY, + 'table': self.table, + 'name': self.chain + } + } + } + ] + ) + ) + + @property + def chains(self) -> list[dict]: + output = self.run_cmd('list chains') + if output[0] != 0: + raise NFTablesCmdFailedError(output) + parsed = json.loads(output[1])['nftables'] + return [record['chain']['name'] for record in parsed if 'chain' in record] + + @property + def tables(self) -> list[dict]: + output = self.run_cmd('list tables') + if output[0] != 0: + raise NFTablesCmdFailedError(output) + parsed = json.loads(output[1])['nftables'] + return [record['table']['name'] for record in parsed if 'table' in record] + + def run_json_cmd(self, cmd: dict) -> tuple: + logger.debug('NFTables json cmd %s', cmd) + with self.plock: + return self.nft.json_cmd(cmd) + + def run_cmd(self, cmd: str) -> tuple: + logger.debug('NFTables cmd %s', cmd) + with self.plock: + return self.nft.cmd(cmd) + + def has_chain(self, chain: str) -> bool: + return chain in self.chains + + def has_table(self, table: str) -> bool: + return table in self.tables + + def add_rule(self, rule: SChainRule) -> None: + if self.has_rule(rule): + return + expr = self.rule_to_expr(rule) + + json_cmd = self._compose_json( + [ + { + 'insert': { + 'rule': { + 'family': self.FAMILY, + 'table': self.table, + 'chain': self.chain, + 'expr': expr, + } + } + } + ] + ) + + rc, _, error = self.run_json_cmd(json_cmd) + if rc != 0: + raise NFTablesCmdFailedError(f'Failed to add allow rule: {error}') + + def remove_rule(self, rule: SChainRule) -> None: + if self.has_rule(rule): + expr = self.rule_to_expr(rule, counter=False) + + output = None + rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}') + if rc != 0: + raise NFTablesCmdFailedError(f'Failed to list rules: {error}') + + current_rules = json.loads(output) + + handle = None + for item in current_rules.get('nftables', []): + if 'rule' in item: + rule_data = item['rule'] + rule_expr = list( + filter(lambda statement: 'counter' not in statement, rule_data['expr']) + ) + if expr == rule_expr: + handle = rule_data.get('handle') + break + + if handle is None: + raise NFTablesCmdFailedError('Rule not found') + + json_cmd = self._compose_json( + [ + { + 'delete': { + 'rule': { + 'family': self.FAMILY, + 'table': self.table, + 'chain': self.chain, + 'handle': handle, + } + } + } + ] + ) + + rc, _, error = self.run_json_cmd(json_cmd) + if rc != 0: + raise NFTablesCmdFailedError(f'Failed to delete rule: {error}') + + @property # type: ignore + def rules(self) -> Iterable[SChainRule]: + return self.get_rules_by_policy(policy='accept') + + def has_rule(self, rule: SChainRule) -> bool: + return rule in self.rules + + def get_rules_by_policy(self, policy: str) -> list[SChainRule]: + output = None + rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}') + if output == '': + return [] + + data = json.loads(output) + rules = [] + + for item in data.get('nftables', []): + if 'rule' in item: + plain_rule = item['rule'] + expr = plain_rule.get('expr', []) + if {policy: None} in expr: + rule = self.expr_to_rule(expr) + if rule: + rules.append(rule) + logger.debug('Rules for policy %s: %s', policy, rules) + return rules + + def get_plain_chain_rules(self) -> str: + self.nft.set_json_output(False) + output = '' + try: + rc, output, error = self.run_cmd( + f'list chain {self.FAMILY} {self.table} {self.chain}' + ) + if rc != 0: + raise NFTablesCmdFailedError(f"Failed to get table content: {error}") + finally: + self.nft.set_json_output(True) + + lines = output.split('\n') + # cleanup table header + if lines[-1] == '': + lines = lines[1:-2] + else: + lines = lines[1:-1] + + # remove leading tab + lines = list(map(lambda line: line[1:], lines)) + # Adding new line at the end to prevent validation failure + lines.append('') + output = '\n'.join(lines) + return output + + @property + def nft_chain_path(self) -> str: + return os.path.join(NFT_CHAIN_BASE_PATH, f'{self.chain}.conf') + + def save_rules(self) -> None: + logger.info('Saving the firewall rules for chain %s', self.chain) + chain_rules = self.get_plain_chain_rules() + with open(self.nft_chain_path, 'w') as nft_chain_file: + nft_chain_file.write(chain_rules) + + def get_saved_rules(self) -> str: + if not os.path.isfile(self.nft_chain_path): + return '' + with open(self.nft_chain_path, 'r') as nft_chain_file: + return nft_chain_file.read() + + def remove_saved_rules(self) -> None: + if os.path.isfile(self.nft_chain_path): + os.remove(self.nft_chain_path) + + def cleanup(self) -> None: + self.delete_chain() diff --git a/core/schains/firewall/rule_controller.py b/core/schains/firewall/rule_controller.py index 51e8920a..68620526 100644 --- a/core/schains/firewall/rule_controller.py +++ b/core/schains/firewall/rule_controller.py @@ -23,7 +23,7 @@ from functools import wraps from typing import Any, Callable, cast, Dict, Iterable, List, Optional, TypeVar -from .firewall_manager import IptablesSChainFirewallManager +from .firewall_manager import IptablesSChainFirewallManager, NFTSchainFirewallManager from .types import ( IFirewallManager, IpRange, @@ -202,9 +202,6 @@ def sync(self) -> None: logger.debug('Syncing firewall rules with %s', erules) self.firewall_manager.update_rules(erules) - def cleanup(self) -> None: - self.firewall_manager.flush() - class IptablesSChainRuleController(SChainRuleController): @configured_only @@ -214,3 +211,36 @@ def create_firewall_manager(self) -> IptablesSChainFirewallManager: self.base_port, # type: ignore self.base_port + self.ports_per_schain - 1 # type: ignore ) + + @configured_only + def is_persistent(self) -> bool: + return True + + @configured_only + def is_inited(self) -> bool: + return True + + @configured_only + def cleanup(self) -> None: + self.firewall_manager.cleanup() + + +class NFTSchainRuleController(SChainRuleController): + @configured_only + def create_firewall_manager(self) -> NFTSchainFirewallManager: + return NFTSchainFirewallManager( + self.name, + self.base_port, # type: ignore + self.base_port + self.ports_per_schain - 1 # type: ignore + ) + + @configured_only + def is_persistent(self) -> bool: + return self.firewall_manager.rules_saved() + + @configured_only + def is_inited(self) -> bool: + return self.firewall_manager.base_config_applied() + + def cleanup(self) -> None: + self.firewall_manager.cleanup() diff --git a/core/schains/firewall/types.py b/core/schains/firewall/types.py index 65ba8885..c30bfc11 100644 --- a/core/schains/firewall/types.py +++ b/core/schains/firewall/types.py @@ -88,6 +88,14 @@ def rules(self) -> Iterable[SChainRule]: # pragma: no cover def has_rule(self, rule: SChainRule) -> bool: # pragma: no cover pass + @abstractmethod + def save_rules(self) -> None: # pragma: no cover + pass + + @abstractmethod + def cleanup(self) -> None: # pragma: no cover + pass + class IFirewallManager(ABC): @property @@ -100,7 +108,7 @@ def update_rules(self, rules: Iterable[SChainRule]) -> None: # pragma: no cover pass @abstractmethod - def flush(self) -> None: # pragma: no cover # noqa + def cleanup(self) -> None: # pragma: no cover # noqa pass @@ -131,3 +139,11 @@ def sync(self) -> None: # pragma: no cover @abstractmethod def cleanup(self) -> None: # pragma: no cover pass + + @abstractmethod + def is_persistent(self) -> bool: # pragma: no cover + pass + + @abstractmethod + def is_inited(self) -> bool: # pragma: no cover + pass diff --git a/core/schains/firewall/utils.py b/core/schains/firewall/utils.py index 737361e1..0788c6df 100644 --- a/core/schains/firewall/utils.py +++ b/core/schains/firewall/utils.py @@ -25,7 +25,8 @@ from skale import Skale from .types import IpRange -from .rule_controller import IptablesSChainRuleController +from .nftables import NFTablesController +from .rule_controller import IptablesSChainRuleController, NFTSchainRuleController logger = logging.getLogger(__name__) @@ -37,6 +38,22 @@ def get_default_rule_controller( own_ip: Optional[str] = None, node_ips: List[str] = [], sync_agent_ranges: Optional[List[IpRange]] = [] +) -> IptablesSChainRuleController: + return get_nftables_rule_controller( + name=name, + base_port=base_port, + own_ip=own_ip, + node_ips=node_ips, + sync_agent_ranges=sync_agent_ranges + ) + + +def get_iptables_rule_controller( + name: str, + base_port: Optional[int] = None, + own_ip: Optional[str] = None, + node_ips: List[str] = [], + sync_agent_ranges: Optional[List[IpRange]] = [] ) -> IptablesSChainRuleController: sync_agent_ranges = sync_agent_ranges or [] logger.info('Creating rule controller for %s', name) @@ -50,6 +67,25 @@ def get_default_rule_controller( ) +def get_nftables_rule_controller( + name: str, + base_port: Optional[int] = None, + own_ip: Optional[str] = None, + node_ips: List[str] = [], + sync_agent_ranges: Optional[List[IpRange]] = [] +) -> NFTSchainRuleController: + sync_agent_ranges = sync_agent_ranges or [] + logger.info('Creating rule controller for %s', name) + logger.debug('Rule controller ranges for %s: %s', name, sync_agent_ranges) + return NFTSchainRuleController( + name=name, + base_port=base_port, + own_ip=own_ip, + node_ips=node_ips, + sync_ip_ranges=sync_agent_ranges + ) + + def get_sync_agent_ranges(skale: Skale) -> List[IpRange]: sync_agent_ranges = [] rnum = skale.sync_manager.get_ip_ranges_number() @@ -66,3 +102,9 @@ def save_sync_ranges(sync_agent_ranges: List[IpRange], path: str) -> None: def ranges_from_plain_tuples(plain_ranges: List[Tuple]) -> List[IpRange]: return list(sorted(map(lambda r: IpRange(*r), plain_ranges))) + + +def cleanup_firewall_for_schain(schain_name: str) -> None: + nft = NFTablesController(chain=schain_name) + nft.cleanup() + nft.remove_saved_rules() diff --git a/tests.Dockerfile b/tests.Dockerfile index b31db00e..75b72f63 100644 --- a/tests.Dockerfile +++ b/tests.Dockerfile @@ -1,3 +1,7 @@ FROM admin:base -RUN pip3 install --no-cache-dir -r requirements-dev.txt +RUN apt update && apt install -y nftables python3-nftables + +RUN pip3 install -r requirements-dev.txt + +ENV PYTHONPATH=${PYTHONPATH}:/usr/lib/python3/dist-packages/ diff --git a/tests/conftest.py b/tests/conftest.py index 973a375e..ac4b92ea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -631,3 +631,13 @@ def ncli_status(_schain_name): yield init_node_cli_status(_schain_name) finally: shutil.rmtree(schain_dir_path, ignore_errors=True) + + +@pytest.fixture() +def nft_chain_folder(): + path = '/etc/nft.conf.d/skale/chains' + try: + os.makedirs(path) + yield path + finally: + shutil.rmtree(path) diff --git a/tests/firewall/default_rule_controller_test.py b/tests/firewall/default_rule_controller_test.py index c2473e16..ea211eab 100644 --- a/tests/firewall/default_rule_controller_test.py +++ b/tests/firewall/default_rule_controller_test.py @@ -1,26 +1,28 @@ + +import concurrent.futures import mock + import pytest -import concurrent.futures from skale.schain_config import PORTS_PER_SCHAIN # noqa -from core.schains.firewall import IptablesController +from core.schains.firewall import NFTablesController from core.schains.firewall.utils import get_default_rule_controller from core.schains.firewall.types import IpRange, SkaledPorts -from tests.firewall.iptables_test import get_rules_through_subprocess + from tools.helper import run_cmd @pytest.fixture def refresh(): - run_cmd(['iptables', '-F']) + run_cmd(['nft', 'flush', 'ruleset']) try: yield finally: - run_cmd(['iptables', '-F']) + run_cmd(['nft', 'flush', 'ruleset']) -def test_get_default_rule_controller(): +def test_get_default_rule_controller(nft_chain_folder): own_ip = '3.3.3.3' node_ips = ['1.1.1.1', '2.2.2.2', '3.3.3.3', '4.4.4.4'] base_port = 10064 @@ -35,6 +37,9 @@ def test_get_default_rule_controller(): node_ips, sync_ip_range ) + + assert rc.is_inited() + assert rc.is_persistent() assert rc.actual_rules() == [] rc.sync() assert rc.expected_rules() == rc.actual_rules() @@ -49,6 +54,10 @@ def test_get_default_rule_controller(): assert hm.add_rule.call_count == 0 assert hm.remove_rule.call_count == 0 + rc.cleanup() + assert not rc.is_inited() + assert not rc.is_persistent() + def sync_rules(*args): rc = get_default_rule_controller(*args) @@ -58,22 +67,6 @@ def sync_rules(*args): return False -def parse_plain_rule(plain_rule): - first_ip, last_ip, port = None, None, None - pr = plain_rule.split() - if '--src-range' in pr: - srange = pr[11] - first_ip, last_ip = srange.split('-') - port = pr[7] - elif '-s' in pr: - first_ip = last_ip = pr[3][:-3] - port = pr[9] - elif '--dport' in pr: - port = pr[7] - - return first_ip, last_ip, int(port) - - def run_concurrent_rc_syncing( node_number, schain_number, @@ -147,59 +140,50 @@ def run_concurrent_rc_syncing( else: assert not r - pr = get_rules_through_subprocess(unique=False)[3:] - rules = [parse_plain_rule(r) for r in pr] + controllers = [NFTablesController(chain=name) for name in schain_names] + rules = [] + for controller in controllers: + rules.extend(controller.rules) + + print([r.port for r in rules]) + print([r.first_ip for r in rules]) - c = IptablesController() # Check that all ip rules are there for ip in node_ips: if ip != own_ip: assert sum( - map(lambda x: x[0] == ip, rules) - ) == 5 * schain_number, ip - assert sum( - map(lambda x: x.first_ip == ip, c.rules) + map(lambda x: x.first_ip == ip, rules) ) == 5 * schain_number, ip # Check that all internal ports rules are there except CATCHUP for p in internal_ports: - assert sum(map(lambda x: x[2] == p, rules)) == node_number - 1, p - assert sum(map(lambda x: x.port == p, c.rules)) == node_number - 1, p + assert sum(map(lambda x: x.port == p, rules)) == node_number - 1, p # Check CATCHUP rules including sync agents rules catchup_e_number = node_number + sync_agent_ranges_number - 1 for p in catchup_ports: - assert sum(map(lambda x: x[2] == p, rules)) == catchup_e_number, p - assert sum(map(lambda x: x.port == p, c.rules)) == catchup_e_number, p + assert sum(map(lambda x: x.port == p, rules)) == catchup_e_number, p # Check ZMQ rules including sync agents rules zmq_e_number = node_number + sync_agent_ranges_number - 1 for p in zmq_ports: - assert sum(map(lambda x: x[2] == p, rules)) == zmq_e_number, p - assert sum(map(lambda x: x.port == p, c.rules)) == zmq_e_number, p + assert sum(map(lambda x: x.port == p, rules)) == zmq_e_number, p # Check sync ip ranges rules for r in sync_agent_ranges: assert sum( - map(lambda x: x[0] == r.start_ip, rules) - ) == schain_number * 2, ip - assert sum( - map(lambda x: x.first_ip == r.start_ip, c.rules) - ) == schain_number * 2, ip - assert sum( - map(lambda x: x[1] == r.end_ip, rules) + map(lambda x: x.first_ip == r.start_ip, rules) ) == schain_number * 2, ip assert sum( - map(lambda x: x.last_ip == r.end_ip, c.rules) + map(lambda x: x.last_ip == r.end_ip, rules) ) == schain_number * 2, ip for port in public_ports: - assert sum(map(lambda x: x[2] == port, rules)) == 1, port - assert sum(map(lambda x: x.port == port, c.rules)) == 1, port + assert sum(map(lambda x: x.port == port, rules)) == 1, port @pytest.mark.parametrize('attempt', range(5)) -def test_concurrent_rc_behavior_no_refresh(attempt): +def test_concurrent_rc_behavior_no_refresh(attempt, nft_chain_folder): node_number = 16 schain_number = 8 own_ip = '1.1.1.1' @@ -215,7 +199,7 @@ def test_concurrent_rc_behavior_no_refresh(attempt): @pytest.mark.parametrize('attempt', range(5)) -def test_concurrent_rc_behavior_with_refresh(attempt, refresh): +def test_concurrent_rc_behavior_with_refresh(attempt, refresh, nft_chain_folder): node_number = 16 schain_number = 8 own_ip = '1.1.1.1' diff --git a/tests/firewall/firewall_manager_test.py b/tests/firewall/firewall_manager_test.py index 719ad1bf..04203acc 100644 --- a/tests/firewall/firewall_manager_test.py +++ b/tests/firewall/firewall_manager_test.py @@ -53,7 +53,7 @@ def test_firewall_manager_update_existed(): assert fm.host_controller.remove_rule.call_count == 0 -def test_firewall_manager_flush(): +def test_firewall_manager_cleanup(): fm = SChainTestFirewallManager('test', 10000, 10064) rules = [ SChainRule(10000, '2.2.2.2'), @@ -63,6 +63,6 @@ def test_firewall_manager_flush(): fm.add_rules(rules) fm.host_controller.add_rule(SChainRule(10072, '2.2.2.2')) - fm.flush() + fm.cleanup() assert list(fm.rules) == [] assert fm.host_controller.has_rule(SChainRule(10072, '2.2.2.2')) diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py new file mode 100644 index 00000000..e77af1fa --- /dev/null +++ b/tests/firewall/nftables_test.py @@ -0,0 +1,161 @@ +import concurrent.futures +import importlib +import os +import time + +import pytest + +from core.schains.firewall.nftables import NFTablesController, NFT_CHAIN_BASE_PATH +from core.schains.firewall.types import SChainRule +from core.schains.firewall.utils import cleanup_firewall_for_schain +from tools.helper import run_cmd + + +@pytest.fixture +def nf_test_tables(): + nft = importlib.import_module('nftables').Nftables() + nft.cmd('flush ruleset') + return nft + + +@pytest.fixture +def filter_table(nf_test_tables): + print(nf_test_tables.cmd('add table inet firewall')) + + +@pytest.fixture +def custom_chain(nf_test_tables, filter_table): + name = 'test-chain' + nf_test_tables.cmd(f'add chain inet firewall skale-{name}') + return name + + +def test_nftables_controller(custom_chain): + nft_controller = NFTablesController(chain='test-chain') + rule_a = SChainRule(10000, '1.1.1.1', '2.2.2.2') + rule_b = SChainRule(10001, '3.3.3.3') + nft_controller.add_rule(rule_a) + nft_controller.add_rule(rule_b) + assert nft_controller.has_rule(rule_a) + assert nft_controller.has_rule(rule_b) + rules = list(nft_controller.rules) + assert sorted(rules) == sorted([rule_b, rule_a]), (rules, sorted([rule_b, rule_a])) + nft_controller.remove_rule(rule_a) + assert not nft_controller.has_rule(rule_a) + assert nft_controller.has_rule(rule_b) + nft_controller.remove_rule(rule_b) + assert not nft_controller.has_rule(rule_a) + + +def test_nftables_controller_duplicates(custom_chain): + rule_a = SChainRule(10000, '1.1.1.1', '2.2.2.2') + manager = NFTablesController(chain='test-chain') + manager.add_rule(rule_a) + rule_b = SChainRule(10001, '3.3.3.3', '4.4.4.4') + manager.add_rule(rule_b) + assert sorted(list(manager.rules)) == sorted([ + SChainRule(port=10001, first_ip='3.3.3.3', last_ip='4.4.4.4'), + SChainRule(port=10000, first_ip='1.1.1.1', last_ip='2.2.2.2') + ]) + assert manager.has_rule(rule_b) + manager.add_rule(rule_b) + assert manager.has_rule(rule_b) + assert sorted(list(manager.rules)) == sorted([ + SChainRule(port=10001, first_ip='3.3.3.3', last_ip='4.4.4.4'), + SChainRule(port=10000, first_ip='1.1.1.1', last_ip='2.2.2.2') + ]) + manager.remove_rule(rule_b) + assert list(manager.rules) == [ + SChainRule(port=10000, first_ip='1.1.1.1', last_ip='2.2.2.2') + ] + + +def test_create_delete_chain(filter_table, nft_chain_folder): + chain_name = 'test-chain' + + output = run_cmd(['nft', 'list', 'chains']).stdout.decode('utf-8') + output == 'table inet firewall {\n}\n' + nft_chain_path = os.path.join(NFT_CHAIN_BASE_PATH, f'skale-{chain_name}.conf') + assert not os.path.isfile(nft_chain_path) + + manager = NFTablesController(chain=chain_name) + manager.create_chain(first_port=10000, last_port=10063) + + chains = run_cmd(['nft', 'list', 'chains']).stdout.decode('utf-8') + assert chains == 'table inet firewall {\n\tchain skale-test-chain {\n\t\ttype filter hook input priority filter; policy accept;\n\t}\n}\n' # noqa + assert os.path.isfile(nft_chain_path) + + manager.cleanup() + chains = run_cmd(['nft', 'list', 'chains']).stdout.decode('utf-8') + assert chains == 'table inet firewall {\n}\n' + assert os.path.isfile(nft_chain_path) + + manager.remove_saved_rules() + assert not os.path.isfile(nft_chain_path) + + +def test_saved_rules(filter_table, nft_chain_folder): + chain_name = 'test-chain' + nft_chain_path = os.path.join(NFT_CHAIN_BASE_PATH, f'skale-{chain_name}.conf') + + manager = NFTablesController(chain=chain_name) + assert not os.path.isfile(nft_chain_path) + manager.create_chain(first_port=10000, last_port=10063) + assert os.path.isfile(nft_chain_path) + assert manager.get_saved_rules() == 'chain skale-test-chain {\n\ttype filter hook input priority filter; policy accept;\n\ttcp dport 10000-10063 counter drop\n}\n' # noqa + + assert os.path.isfile(nft_chain_path) + + manager.remove_saved_rules() + assert not os.path.isfile(nft_chain_path) + + +def test_cleanup_firewall_for_schain(filter_table, nft_chain_folder): + chain_name = 'test-chain' + nft_chain_path = os.path.join(NFT_CHAIN_BASE_PATH, f'skale-{chain_name}.conf') + + manager = NFTablesController(chain=chain_name) + manager.create_chain(first_port=10000, last_port=10063) + + cleanup_firewall_for_schain(schain_name=chain_name) + chains = run_cmd(['nft', 'list', 'chains']).stdout.decode('utf-8') + assert chains == 'table inet firewall {\n}\n' + assert not os.path.isfile(nft_chain_path) + + +def add_remove_rule(srule, refresh): + manager = NFTablesController(chain='test') + manager.add_rule(srule) + time.sleep(1) + if not manager.has_rule(srule): + return False + time.sleep(1) + manager.remove_rule(srule) + return True + + +def generate_srules(number=5): + return [ + SChainRule( + 10000 + 1, + f'{i}.{i}.{i}.{i}', f'{i + 1}.{i + 1}.{i + 1}.{i + 1}' + ) + for i in range(1, number * 2, 2) + ] + + +def test_nftables_manager_parallel(custom_chain): + srules = generate_srules(number=12) + + futures = [] + with concurrent.futures.ProcessPoolExecutor(max_workers=12) as executor: + futures = [ + executor.submit(add_remove_rule, srule) + for srule in srules + ] + + for future in concurrent.futures.as_completed(futures): + assert future.result + manager = NFTablesController(custom_chain) + time.sleep(10) + assert len(list(manager.rules)) == 0 diff --git a/tests/schains/checks_test.py b/tests/schains/checks_test.py index f0d67f32..4ac14836 100644 --- a/tests/schains/checks_test.py +++ b/tests/schains/checks_test.py @@ -180,6 +180,8 @@ def test_volume_check(schain_checks, sample_false_checks, dutils): def test_firewall_rules_check(schain_checks, rules_unsynced_checks): schain_checks.rc.sync() + res = schain_checks.firewall_rules + print(res.data) assert schain_checks.firewall_rules assert not rules_unsynced_checks.firewall_rules.status diff --git a/tests/schains/cleaner_test.py b/tests/schains/cleaner_test.py index d16b41fd..ea45b25b 100644 --- a/tests/schains/cleaner_test.py +++ b/tests/schains/cleaner_test.py @@ -239,7 +239,8 @@ def test_get_schains_on_node(schain_dirs_for_monitor, ]).issubset(set(result)) -def test_remove_schain(skale, schain_db, node_config, dutils): +@mock.patch('core.schains.cleaner.cleanup_firewall_for_schain') +def test_remove_schain(cleanup_firewall_for_schain, skale, schain_db, node_config, dutils): schain_name = schain_db remove_schain(skale, node_config.id, schain_name, msg='Test remove_schain', dutils=dutils) container_name = SCHAIN_CONTAINER_NAME_TEMPLATE.format(schain_name) @@ -250,7 +251,9 @@ def test_remove_schain(skale, schain_db, node_config, dutils): assert record.is_deleted is True +@mock.patch('core.schains.cleaner.cleanup_firewall_for_schain') def test_cleanup_schain( + cleanup_firewall_rules, schain_db, node_config, schain_on_contracts, diff --git a/tests/utils.py b/tests/utils.py index dc33bf91..ac1f7f7d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -222,11 +222,20 @@ def rules(self): def has_rule(self, srule): return srule in self._rules + def save_rules(self): + pass + + def cleanup(self): + pass + class SChainTestFirewallManager(SChainFirewallManager): def create_host_controller(self): return HostTestFirewallController() + def cleanup(self): + self.remove_rules(self.rules) + class SChainTestRuleController(SChainRuleController): def create_firewall_manager(self): @@ -236,6 +245,15 @@ def create_firewall_manager(self): self.base_port + self.ports_per_schain ) + def is_persistent(self) -> bool: + return True + + def is_inited(self) -> bool: + return True + + def cleanup(self) -> None: + self.firewall_manager.cleanup() + def get_test_rule_controller( name, diff --git a/tools/configs/__init__.py b/tools/configs/__init__.py index 4794de04..1e2fd423 100644 --- a/tools/configs/__init__.py +++ b/tools/configs/__init__.py @@ -106,3 +106,6 @@ SYNC_NODE = os.getenv('SYNC_NODE') == 'True' DOCKER_NODE_CONFIG_FILEPATH = os.path.join(NODE_DATA_PATH, 'docker.json') + +NFT_CHAIN_BASE_PATH = '/etc/nft.conf.d/skale/chains' +NFT_CHAIN_CONFIG_WILDCARD = os.path.join(NFT_CHAIN_BASE_PATH, '*')