diff --git a/Dockerfile b/Dockerfile
index 0a2c24bb..540b256f 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,6 +1,6 @@
FROM python:3.11-bookworm
-RUN apt-get update && apt-get install -y wget git libxslt-dev iptables kmod swig
+RUN apt-get update && apt-get install -y wget git libxslt-dev iptables kmod swig nftables python3-nftables
RUN mkdir /usr/src/admin
WORKDIR /usr/src/admin
@@ -8,12 +8,13 @@ WORKDIR /usr/src/admin
COPY requirements.txt ./
COPY requirements-dev.txt ./
-RUN pip3 install --no-cache-dir -r requirements.txt
+RUN pip3 install -r requirements.txt
COPY . .
RUN update-alternatives --set iptables /usr/sbin/iptables-legacy && \
update-alternatives --set ip6tables /usr/sbin/ip6tables-legacy
-ENV PYTHONPATH="/usr/src/admin"
+ENV PYTHONPATH="/usr/src/admin":/usr/lib/python3/dist-packages/
+
ENV COLUMNS=80
diff --git a/VERSION b/VERSION
index 834f2629..c8e38b61 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-2.8.0
+2.9.0
diff --git a/core/schains/checks.py b/core/schains/checks.py
index 8beb2673..c0b5079f 100644
--- a/core/schains/checks.py
+++ b/core/schains/checks.py
@@ -301,6 +301,11 @@ def volume(self) -> CheckRes:
@property
def firewall_rules(self) -> CheckRes:
"""Checks that firewall rules are set correctly"""
+ data = {
+ 'inited': False,
+ 'rules': False,
+ 'persistent': False,
+ }
if self.config:
conf = self.cfm.skaled_config
base_port = get_base_port_from_config(conf)
@@ -311,8 +316,15 @@ def firewall_rules(self) -> CheckRes:
base_port=base_port, own_ip=own_ip, node_ips=node_ips, sync_ip_ranges=ranges
)
logger.debug(f'Rule controller {self.rc.expected_rules()}')
- return CheckRes(self.rc.is_rules_synced())
- return CheckRes(False)
+ data.update({
+ 'inited': self.rc.is_inited(),
+ 'rules': self.rc.is_rules_synced(),
+ 'persistent': self.rc.is_persistent(),
+ })
+ logger.debug('Firewall rules check: %s', data)
+ status = all(data.values())
+ return CheckRes(status=status, data=data)
+ return CheckRes(status=False, data=data)
@property
def skaled_container(self) -> CheckRes:
diff --git a/core/schains/cleaner.py b/core/schains/cleaner.py
index 7fd291ef..881a2e63 100644
--- a/core/schains/cleaner.py
+++ b/core/schains/cleaner.py
@@ -17,10 +17,12 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+import glob
import logging
import os
import shutil
from multiprocessing import Process
+from pathlib import Path
from typing import Optional
from sgx import SgxClient
@@ -28,22 +30,20 @@
from core.node import get_current_nodes, get_skale_node_version
from core.schains.checks import SChainChecks
-from core.schains.config.file_manager import ConfigFileManager
from core.schains.config.directory import schain_config_dir
from core.schains.dkg.utils import get_secret_key_share_filepath
-from core.schains.firewall.utils import get_default_rule_controller
-from core.schains.config.helper import (
- get_base_port_from_config,
- get_node_ips_from_config,
- get_own_ip_from_config,
-)
+from core.schains.firewall.utils import cleanup_firewall_for_schain, get_default_rule_controller
from core.schains.process import ProcessReport, terminate_process
from core.schains.runner import get_container_name, is_exited
from core.schains.external_config import ExternalConfig
from core.schains.types import ContainerType
from core.schains.firewall.utils import get_sync_agent_ranges
-from tools.configs import SGX_CERTIFICATES_FOLDER, SYNC_NODE
+from tools.configs import (
+ NFT_CHAIN_CONFIG_WILDCARD,
+ SGX_CERTIFICATES_FOLDER,
+ SYNC_NODE
+)
from tools.configs.schains import SCHAINS_DIR_PATH
from tools.configs.containers import SCHAIN_CONTAINER, IMA_CONTAINER, SCHAIN_STOP_TIMEOUT
from tools.docker_utils import DockerUtils
@@ -136,18 +136,36 @@ def get_schains_with_containers(dutils=None):
]
+def get_schains_firewall_configs() -> list:
+ return list(map(lambda path: Path(path).stem, glob.glob(NFT_CHAIN_CONFIG_WILDCARD)))
+
+
def get_schains_on_node(dutils=None):
dutils = dutils or DockerUtils()
schains_with_dirs = os.listdir(SCHAINS_DIR_PATH)
schains_with_container = get_schains_with_containers(dutils)
schains_active_records = get_schains_names()
+ schains_firewall_configs = list(
+ map(
+ lambda name: name.removeprefix('skale-'),
+ get_schains_firewall_configs()
+ )
+ )
logger.info(
- 'dirs %s, containers: %s, records: %s',
+ 'dirs %s, containers: %s, records: %s, firewall configs: %s',
schains_with_dirs,
schains_with_container,
- schains_active_records
+ schains_active_records,
+ schains_firewall_configs
+ )
+ return sorted(
+ merged_unique(
+ schains_with_dirs,
+ schains_with_container,
+ schains_active_records,
+ schains_firewall_configs
+ )
)
- return sorted(merged_unique(schains_with_dirs, schains_with_container, schains_active_records))
def schain_names_to_ids(skale, schain_names):
@@ -258,16 +276,10 @@ def cleanup_schain(
remove_schain_container(schain_name, dutils=dutils)
if check_status['volume']:
remove_schain_volume(schain_name, dutils=dutils)
- if check_status['firewall_rules']:
- conf = ConfigFileManager(schain_name).skaled_config
- base_port = get_base_port_from_config(conf)
- own_ip = get_own_ip_from_config(conf)
- node_ips = get_node_ips_from_config(conf)
- ranges = []
- if estate is not None:
- ranges = estate.ranges
- rc.configure(base_port=base_port, own_ip=own_ip, node_ips=node_ips, sync_ip_ranges=ranges)
- rc.cleanup()
+ if any(checks.firewall_rules.data):
+ logger.info('Cleaning firewall for %s', schain_name)
+ cleanup_firewall_for_schain(schain_name)
+
if estate is not None and estate.ima_linked:
if check_status.get('ima_container', False) or is_exited(
schain_name, container_type=ContainerType.ima, dutils=dutils
diff --git a/core/schains/firewall/__init__.py b/core/schains/firewall/__init__.py
index 8edbd1a7..1bba60b7 100644
--- a/core/schains/firewall/__init__.py
+++ b/core/schains/firewall/__init__.py
@@ -19,6 +19,7 @@
from .firewall_manager import SChainFirewallManager # noqa
from .iptables import IptablesController # noqa
+from .nftables import NFTablesController # noqa
from .rule_controller import SChainRuleController # noqa
from .types import IRuleController # noqa
from .utils import get_default_rule_controller # noqa
diff --git a/core/schains/firewall/firewall_manager.py b/core/schains/firewall/firewall_manager.py
index e2216fc7..5ae6cbe2 100644
--- a/core/schains/firewall/firewall_manager.py
+++ b/core/schains/firewall/firewall_manager.py
@@ -22,6 +22,7 @@
from typing import Iterable, Optional
from core.schains.firewall.iptables import IptablesController
+from core.schains.firewall.nftables import NFTablesController
from core.schains.firewall.types import (
IFirewallManager,
IHostFirewallController,
@@ -70,6 +71,11 @@ def update_rules(self, rules: Iterable[SChainRule]) -> None:
rules_to_remove = actual_rules - expected_rules
self.add_rules(rules_to_add)
self.remove_rules(rules_to_remove)
+ self.save_rules()
+
+ def save_rules(self) -> None:
+ """ Saves rules into persistent storage """
+ self.host_controller.save_rules()
def add_rules(self, rules: Iterable[SChainRule]) -> None:
logger.debug('Adding rules %s', rules)
@@ -81,10 +87,32 @@ def remove_rules(self, rules: Iterable[SChainRule]) -> None:
for rule in rules:
self.host_controller.remove_rule(rule)
- def flush(self) -> None:
- self.remove_rules(self.rules)
-
class IptablesSChainFirewallManager(SChainFirewallManager):
def create_host_controller(self) -> IptablesController:
return IptablesController()
+
+ def cleanup(self) -> None:
+ self.remove_rules(self.rules)
+
+
+class NFTSchainFirewallManager(SChainFirewallManager):
+ def create_host_controller(self) -> NFTablesController:
+ nc_controller = NFTablesController(chain=self.name)
+ nc_controller.create_table()
+ nc_controller.create_chain(self.first_port, self.last_port)
+ return nc_controller
+
+ def rules_saved(self) -> bool:
+ saved = self.host_controller.get_saved_rules()
+ if saved == '':
+ return False
+ return saved == self.host_controller.get_plain_chain_rules()
+
+ def base_config_applied(self) -> bool:
+ return self.host_controller.has_chain(self.host_controller.chain) and \
+ self.host_controller.has_drop_rule(self.first_port, self.last_port)
+
+ def cleanup(self) -> None:
+ self.host_controller.cleanup()
+ self.host_controller.remove_saved_rules()
diff --git a/core/schains/firewall/iptables.py b/core/schains/firewall/iptables.py
index 1d28c403..589250d6 100644
--- a/core/schains/firewall/iptables.py
+++ b/core/schains/firewall/iptables.py
@@ -139,3 +139,9 @@ def from_ip_network(cls, ip: str) -> str:
@classmethod
def to_ip_network(cls, ip: str) -> str:
return str(ipaddress.ip_network(ip))
+
+ def save_rules(self) -> None:
+ raise NotImplementedError('save_rules is not implemented for iptables host controller')
+
+ def cleanup(self) -> None:
+ raise NotImplementedError('cleanup is not implemented for iptables host controller')
diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py
new file mode 100644
index 00000000..e2a56e39
--- /dev/null
+++ b/core/schains/firewall/nftables.py
@@ -0,0 +1,391 @@
+# -*- coding: utf-8 -*-
+# This file is part of SKALE Admin
+#
+# Copyright (C) 2024 SKALE Labs
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+
+import importlib
+import json
+import logging
+import multiprocessing
+import os
+from typing import Iterable
+
+from core.schains.firewall.types import IHostFirewallController, SChainRule
+
+from tools.configs import NFT_CHAIN_BASE_PATH
+
+
+logger = logging.getLogger(__name__)
+
+
+TABLE = 'firewall'
+CHAIN = 'skale'
+
+
+class NFTablesCmdFailedError(Exception):
+ pass
+
+
+class NFTablesController(IHostFirewallController):
+ plock = multiprocessing.Lock()
+ FAMILY = 'inet'
+
+ def __init__(self, chain: str, table: str = TABLE) -> None:
+ self.table = table
+ self.chain = f'skale-{chain}'
+ self._nftables = importlib.import_module('nftables')
+ self.nft = self._nftables.Nftables()
+ self.nft.set_json_output(True)
+ self.nft.set_stateless_output(True)
+
+ @classmethod
+ def rule_to_expr(cls, rule: SChainRule, counter: bool = True) -> list:
+ expr = []
+
+ if rule.first_ip:
+ if rule.last_ip == rule.first_ip:
+ expr.append(
+ {
+ 'match': {
+ 'op': '==',
+ 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}},
+ 'right': f'{rule.first_ip}',
+ }
+ }
+ )
+ else:
+ expr.append(
+ {
+ 'match': {
+ 'op': '==',
+ 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}},
+ 'right': {'range': [f'{rule.first_ip}', f'{rule.last_ip}']},
+ }
+ }
+ )
+
+ if rule.port:
+ expr.append(
+ {
+ 'match': {
+ 'op': '==',
+ 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}},
+ 'right': rule.port,
+ }
+ }
+ )
+
+ if counter:
+ expr.append({'counter': None})
+
+ expr.append({'accept': None})
+ return expr
+
+ @classmethod
+ def expr_to_rule(self, expr: list) -> None:
+ port, first_ip, last_ip = None, None, None
+ for item in expr:
+ if 'match' in item:
+ match = item['match']
+
+ if match.get('left', {}).get('payload', {}).get('field') == 'dport':
+ port = match.get('right')
+
+ if match.get('left', {}).get('payload', {}).get('field') == 'saddr':
+ right = match.get('right')
+ if isinstance(right, str):
+ first_ip = right
+ else:
+ first_ip, last_ip = right['range']
+
+ if any([port, first_ip, last_ip]):
+ return SChainRule(port=port, first_ip=first_ip, last_ip=last_ip)
+
+ def _compose_json(self, commands: list[dict]) -> dict:
+ json_cmd = {'nftables': commands}
+ self.nft.json_validate(json_cmd)
+ return json_cmd
+
+ def create_table(self) -> None:
+ if not self.has_table(self.table):
+ return self.run_cmd(f'add table inet {self.table}')
+
+ def has_drop_rule(self, first_port: int, last_port: int) -> bool:
+ expr = [
+ {
+ 'match': {
+ 'op': '==',
+ 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}},
+ 'right': {'range': [first_port, last_port]},
+ }
+ },
+ {'counter': None},
+ {'drop': None},
+ ]
+
+ return self.expr_to_rule(expr) in self.get_rules_by_policy(policy='drop')
+
+ def add_schain_drop_rule(self, first_port: int, last_port: int) -> None:
+ if not self.has_drop_rule(first_port, last_port):
+ expr = [
+ {
+ 'match': {
+ 'op': '==',
+ 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}},
+ 'right': {'range': [first_port, last_port]},
+ }
+ },
+ {'counter': None},
+ {'drop': None},
+ ]
+
+ cmd = {
+ 'nftables': [
+ {
+ 'add': {
+ 'rule': {
+ 'family': self.FAMILY,
+ 'table': self.table,
+ 'chain': self.chain,
+ 'expr': expr,
+ }
+ }
+ }
+ ]
+ }
+ self.run_json_cmd(cmd)
+ logger.info('Added drop rule for chain %s', self.chain)
+
+ def create_chain(self, first_port: int, last_port: int) -> None:
+ if not self.has_chain(self.chain):
+ logger.info('Creating chain %s', self.chain)
+ self.run_json_cmd(
+ self._compose_json(
+ [
+ {
+ 'add': {
+ 'chain': {
+ 'family': self.FAMILY,
+ 'table': self.table,
+ 'name': self.chain,
+ 'hook': 'input',
+ 'type': 'filter',
+ 'prio': 0,
+ 'policy': 'accept',
+ }
+ }
+ }
+ ]
+ )
+ )
+ self.add_schain_drop_rule(first_port, last_port)
+ self.save_rules()
+
+ def delete_chain(self) -> None:
+ if self.has_chain(self.chain):
+ logger.info('Removing chain %s', self.chain)
+ self.run_json_cmd(
+ self._compose_json(
+ [
+ {
+ 'delete': {
+ 'chain': {
+ 'family': self.FAMILY,
+ 'table': self.table,
+ 'name': self.chain
+ }
+ }
+ }
+ ]
+ )
+ )
+
+ @property
+ def chains(self) -> list[dict]:
+ output = self.run_cmd('list chains')
+ if output[0] != 0:
+ raise NFTablesCmdFailedError(output)
+ parsed = json.loads(output[1])['nftables']
+ return [record['chain']['name'] for record in parsed if 'chain' in record]
+
+ @property
+ def tables(self) -> list[dict]:
+ output = self.run_cmd('list tables')
+ if output[0] != 0:
+ raise NFTablesCmdFailedError(output)
+ parsed = json.loads(output[1])['nftables']
+ return [record['table']['name'] for record in parsed if 'table' in record]
+
+ def run_json_cmd(self, cmd: dict) -> tuple:
+ logger.debug('NFTables json cmd %s', cmd)
+ with self.plock:
+ return self.nft.json_cmd(cmd)
+
+ def run_cmd(self, cmd: str) -> tuple:
+ logger.debug('NFTables cmd %s', cmd)
+ with self.plock:
+ return self.nft.cmd(cmd)
+
+ def has_chain(self, chain: str) -> bool:
+ return chain in self.chains
+
+ def has_table(self, table: str) -> bool:
+ return table in self.tables
+
+ def add_rule(self, rule: SChainRule) -> None:
+ if self.has_rule(rule):
+ return
+ expr = self.rule_to_expr(rule)
+
+ json_cmd = self._compose_json(
+ [
+ {
+ 'insert': {
+ 'rule': {
+ 'family': self.FAMILY,
+ 'table': self.table,
+ 'chain': self.chain,
+ 'expr': expr,
+ }
+ }
+ }
+ ]
+ )
+
+ rc, _, error = self.run_json_cmd(json_cmd)
+ if rc != 0:
+ raise NFTablesCmdFailedError(f'Failed to add allow rule: {error}')
+
+ def remove_rule(self, rule: SChainRule) -> None:
+ if self.has_rule(rule):
+ expr = self.rule_to_expr(rule, counter=False)
+
+ output = None
+ rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}')
+ if rc != 0:
+ raise NFTablesCmdFailedError(f'Failed to list rules: {error}')
+
+ current_rules = json.loads(output)
+
+ handle = None
+ for item in current_rules.get('nftables', []):
+ if 'rule' in item:
+ rule_data = item['rule']
+ rule_expr = list(
+ filter(lambda statement: 'counter' not in statement, rule_data['expr'])
+ )
+ if expr == rule_expr:
+ handle = rule_data.get('handle')
+ break
+
+ if handle is None:
+ raise NFTablesCmdFailedError('Rule not found')
+
+ json_cmd = self._compose_json(
+ [
+ {
+ 'delete': {
+ 'rule': {
+ 'family': self.FAMILY,
+ 'table': self.table,
+ 'chain': self.chain,
+ 'handle': handle,
+ }
+ }
+ }
+ ]
+ )
+
+ rc, _, error = self.run_json_cmd(json_cmd)
+ if rc != 0:
+ raise NFTablesCmdFailedError(f'Failed to delete rule: {error}')
+
+ @property # type: ignore
+ def rules(self) -> Iterable[SChainRule]:
+ return self.get_rules_by_policy(policy='accept')
+
+ def has_rule(self, rule: SChainRule) -> bool:
+ return rule in self.rules
+
+ def get_rules_by_policy(self, policy: str) -> list[SChainRule]:
+ output = None
+ rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}')
+ if output == '':
+ return []
+
+ data = json.loads(output)
+ rules = []
+
+ for item in data.get('nftables', []):
+ if 'rule' in item:
+ plain_rule = item['rule']
+ expr = plain_rule.get('expr', [])
+ if {policy: None} in expr:
+ rule = self.expr_to_rule(expr)
+ if rule:
+ rules.append(rule)
+ logger.debug('Rules for policy %s: %s', policy, rules)
+ return rules
+
+ def get_plain_chain_rules(self) -> str:
+ self.nft.set_json_output(False)
+ output = ''
+ try:
+ rc, output, error = self.run_cmd(
+ f'list chain {self.FAMILY} {self.table} {self.chain}'
+ )
+ if rc != 0:
+ raise NFTablesCmdFailedError(f"Failed to get table content: {error}")
+ finally:
+ self.nft.set_json_output(True)
+
+ lines = output.split('\n')
+ # cleanup table header
+ if lines[-1] == '':
+ lines = lines[1:-2]
+ else:
+ lines = lines[1:-1]
+
+ # remove leading tab
+ lines = list(map(lambda line: line[1:], lines))
+ # Adding new line at the end to prevent validation failure
+ lines.append('')
+ output = '\n'.join(lines)
+ return output
+
+ @property
+ def nft_chain_path(self) -> str:
+ return os.path.join(NFT_CHAIN_BASE_PATH, f'{self.chain}.conf')
+
+ def save_rules(self) -> None:
+ logger.info('Saving the firewall rules for chain %s', self.chain)
+ chain_rules = self.get_plain_chain_rules()
+ with open(self.nft_chain_path, 'w') as nft_chain_file:
+ nft_chain_file.write(chain_rules)
+
+ def get_saved_rules(self) -> str:
+ if not os.path.isfile(self.nft_chain_path):
+ return ''
+ with open(self.nft_chain_path, 'r') as nft_chain_file:
+ return nft_chain_file.read()
+
+ def remove_saved_rules(self) -> None:
+ if os.path.isfile(self.nft_chain_path):
+ os.remove(self.nft_chain_path)
+
+ def cleanup(self) -> None:
+ self.delete_chain()
diff --git a/core/schains/firewall/rule_controller.py b/core/schains/firewall/rule_controller.py
index 51e8920a..68620526 100644
--- a/core/schains/firewall/rule_controller.py
+++ b/core/schains/firewall/rule_controller.py
@@ -23,7 +23,7 @@
from functools import wraps
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, TypeVar
-from .firewall_manager import IptablesSChainFirewallManager
+from .firewall_manager import IptablesSChainFirewallManager, NFTSchainFirewallManager
from .types import (
IFirewallManager,
IpRange,
@@ -202,9 +202,6 @@ def sync(self) -> None:
logger.debug('Syncing firewall rules with %s', erules)
self.firewall_manager.update_rules(erules)
- def cleanup(self) -> None:
- self.firewall_manager.flush()
-
class IptablesSChainRuleController(SChainRuleController):
@configured_only
@@ -214,3 +211,36 @@ def create_firewall_manager(self) -> IptablesSChainFirewallManager:
self.base_port, # type: ignore
self.base_port + self.ports_per_schain - 1 # type: ignore
)
+
+ @configured_only
+ def is_persistent(self) -> bool:
+ return True
+
+ @configured_only
+ def is_inited(self) -> bool:
+ return True
+
+ @configured_only
+ def cleanup(self) -> None:
+ self.firewall_manager.cleanup()
+
+
+class NFTSchainRuleController(SChainRuleController):
+ @configured_only
+ def create_firewall_manager(self) -> NFTSchainFirewallManager:
+ return NFTSchainFirewallManager(
+ self.name,
+ self.base_port, # type: ignore
+ self.base_port + self.ports_per_schain - 1 # type: ignore
+ )
+
+ @configured_only
+ def is_persistent(self) -> bool:
+ return self.firewall_manager.rules_saved()
+
+ @configured_only
+ def is_inited(self) -> bool:
+ return self.firewall_manager.base_config_applied()
+
+ def cleanup(self) -> None:
+ self.firewall_manager.cleanup()
diff --git a/core/schains/firewall/types.py b/core/schains/firewall/types.py
index 65ba8885..c30bfc11 100644
--- a/core/schains/firewall/types.py
+++ b/core/schains/firewall/types.py
@@ -88,6 +88,14 @@ def rules(self) -> Iterable[SChainRule]: # pragma: no cover
def has_rule(self, rule: SChainRule) -> bool: # pragma: no cover
pass
+ @abstractmethod
+ def save_rules(self) -> None: # pragma: no cover
+ pass
+
+ @abstractmethod
+ def cleanup(self) -> None: # pragma: no cover
+ pass
+
class IFirewallManager(ABC):
@property
@@ -100,7 +108,7 @@ def update_rules(self, rules: Iterable[SChainRule]) -> None: # pragma: no cover
pass
@abstractmethod
- def flush(self) -> None: # pragma: no cover # noqa
+ def cleanup(self) -> None: # pragma: no cover # noqa
pass
@@ -131,3 +139,11 @@ def sync(self) -> None: # pragma: no cover
@abstractmethod
def cleanup(self) -> None: # pragma: no cover
pass
+
+ @abstractmethod
+ def is_persistent(self) -> bool: # pragma: no cover
+ pass
+
+ @abstractmethod
+ def is_inited(self) -> bool: # pragma: no cover
+ pass
diff --git a/core/schains/firewall/utils.py b/core/schains/firewall/utils.py
index 737361e1..0788c6df 100644
--- a/core/schains/firewall/utils.py
+++ b/core/schains/firewall/utils.py
@@ -25,7 +25,8 @@
from skale import Skale
from .types import IpRange
-from .rule_controller import IptablesSChainRuleController
+from .nftables import NFTablesController
+from .rule_controller import IptablesSChainRuleController, NFTSchainRuleController
logger = logging.getLogger(__name__)
@@ -37,6 +38,22 @@ def get_default_rule_controller(
own_ip: Optional[str] = None,
node_ips: List[str] = [],
sync_agent_ranges: Optional[List[IpRange]] = []
+) -> IptablesSChainRuleController:
+ return get_nftables_rule_controller(
+ name=name,
+ base_port=base_port,
+ own_ip=own_ip,
+ node_ips=node_ips,
+ sync_agent_ranges=sync_agent_ranges
+ )
+
+
+def get_iptables_rule_controller(
+ name: str,
+ base_port: Optional[int] = None,
+ own_ip: Optional[str] = None,
+ node_ips: List[str] = [],
+ sync_agent_ranges: Optional[List[IpRange]] = []
) -> IptablesSChainRuleController:
sync_agent_ranges = sync_agent_ranges or []
logger.info('Creating rule controller for %s', name)
@@ -50,6 +67,25 @@ def get_default_rule_controller(
)
+def get_nftables_rule_controller(
+ name: str,
+ base_port: Optional[int] = None,
+ own_ip: Optional[str] = None,
+ node_ips: List[str] = [],
+ sync_agent_ranges: Optional[List[IpRange]] = []
+) -> NFTSchainRuleController:
+ sync_agent_ranges = sync_agent_ranges or []
+ logger.info('Creating rule controller for %s', name)
+ logger.debug('Rule controller ranges for %s: %s', name, sync_agent_ranges)
+ return NFTSchainRuleController(
+ name=name,
+ base_port=base_port,
+ own_ip=own_ip,
+ node_ips=node_ips,
+ sync_ip_ranges=sync_agent_ranges
+ )
+
+
def get_sync_agent_ranges(skale: Skale) -> List[IpRange]:
sync_agent_ranges = []
rnum = skale.sync_manager.get_ip_ranges_number()
@@ -66,3 +102,9 @@ def save_sync_ranges(sync_agent_ranges: List[IpRange], path: str) -> None:
def ranges_from_plain_tuples(plain_ranges: List[Tuple]) -> List[IpRange]:
return list(sorted(map(lambda r: IpRange(*r), plain_ranges)))
+
+
+def cleanup_firewall_for_schain(schain_name: str) -> None:
+ nft = NFTablesController(chain=schain_name)
+ nft.cleanup()
+ nft.remove_saved_rules()
diff --git a/tests.Dockerfile b/tests.Dockerfile
index b31db00e..75b72f63 100644
--- a/tests.Dockerfile
+++ b/tests.Dockerfile
@@ -1,3 +1,7 @@
FROM admin:base
-RUN pip3 install --no-cache-dir -r requirements-dev.txt
+RUN apt update && apt install -y nftables python3-nftables
+
+RUN pip3 install -r requirements-dev.txt
+
+ENV PYTHONPATH=${PYTHONPATH}:/usr/lib/python3/dist-packages/
diff --git a/tests/conftest.py b/tests/conftest.py
index 973a375e..ac4b92ea 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -631,3 +631,13 @@ def ncli_status(_schain_name):
yield init_node_cli_status(_schain_name)
finally:
shutil.rmtree(schain_dir_path, ignore_errors=True)
+
+
+@pytest.fixture()
+def nft_chain_folder():
+ path = '/etc/nft.conf.d/skale/chains'
+ try:
+ os.makedirs(path)
+ yield path
+ finally:
+ shutil.rmtree(path)
diff --git a/tests/firewall/default_rule_controller_test.py b/tests/firewall/default_rule_controller_test.py
index c2473e16..ea211eab 100644
--- a/tests/firewall/default_rule_controller_test.py
+++ b/tests/firewall/default_rule_controller_test.py
@@ -1,26 +1,28 @@
+
+import concurrent.futures
import mock
+
import pytest
-import concurrent.futures
from skale.schain_config import PORTS_PER_SCHAIN # noqa
-from core.schains.firewall import IptablesController
+from core.schains.firewall import NFTablesController
from core.schains.firewall.utils import get_default_rule_controller
from core.schains.firewall.types import IpRange, SkaledPorts
-from tests.firewall.iptables_test import get_rules_through_subprocess
+
from tools.helper import run_cmd
@pytest.fixture
def refresh():
- run_cmd(['iptables', '-F'])
+ run_cmd(['nft', 'flush', 'ruleset'])
try:
yield
finally:
- run_cmd(['iptables', '-F'])
+ run_cmd(['nft', 'flush', 'ruleset'])
-def test_get_default_rule_controller():
+def test_get_default_rule_controller(nft_chain_folder):
own_ip = '3.3.3.3'
node_ips = ['1.1.1.1', '2.2.2.2', '3.3.3.3', '4.4.4.4']
base_port = 10064
@@ -35,6 +37,9 @@ def test_get_default_rule_controller():
node_ips,
sync_ip_range
)
+
+ assert rc.is_inited()
+ assert rc.is_persistent()
assert rc.actual_rules() == []
rc.sync()
assert rc.expected_rules() == rc.actual_rules()
@@ -49,6 +54,10 @@ def test_get_default_rule_controller():
assert hm.add_rule.call_count == 0
assert hm.remove_rule.call_count == 0
+ rc.cleanup()
+ assert not rc.is_inited()
+ assert not rc.is_persistent()
+
def sync_rules(*args):
rc = get_default_rule_controller(*args)
@@ -58,22 +67,6 @@ def sync_rules(*args):
return False
-def parse_plain_rule(plain_rule):
- first_ip, last_ip, port = None, None, None
- pr = plain_rule.split()
- if '--src-range' in pr:
- srange = pr[11]
- first_ip, last_ip = srange.split('-')
- port = pr[7]
- elif '-s' in pr:
- first_ip = last_ip = pr[3][:-3]
- port = pr[9]
- elif '--dport' in pr:
- port = pr[7]
-
- return first_ip, last_ip, int(port)
-
-
def run_concurrent_rc_syncing(
node_number,
schain_number,
@@ -147,59 +140,50 @@ def run_concurrent_rc_syncing(
else:
assert not r
- pr = get_rules_through_subprocess(unique=False)[3:]
- rules = [parse_plain_rule(r) for r in pr]
+ controllers = [NFTablesController(chain=name) for name in schain_names]
+ rules = []
+ for controller in controllers:
+ rules.extend(controller.rules)
+
+ print([r.port for r in rules])
+ print([r.first_ip for r in rules])
- c = IptablesController()
# Check that all ip rules are there
for ip in node_ips:
if ip != own_ip:
assert sum(
- map(lambda x: x[0] == ip, rules)
- ) == 5 * schain_number, ip
- assert sum(
- map(lambda x: x.first_ip == ip, c.rules)
+ map(lambda x: x.first_ip == ip, rules)
) == 5 * schain_number, ip
# Check that all internal ports rules are there except CATCHUP
for p in internal_ports:
- assert sum(map(lambda x: x[2] == p, rules)) == node_number - 1, p
- assert sum(map(lambda x: x.port == p, c.rules)) == node_number - 1, p
+ assert sum(map(lambda x: x.port == p, rules)) == node_number - 1, p
# Check CATCHUP rules including sync agents rules
catchup_e_number = node_number + sync_agent_ranges_number - 1
for p in catchup_ports:
- assert sum(map(lambda x: x[2] == p, rules)) == catchup_e_number, p
- assert sum(map(lambda x: x.port == p, c.rules)) == catchup_e_number, p
+ assert sum(map(lambda x: x.port == p, rules)) == catchup_e_number, p
# Check ZMQ rules including sync agents rules
zmq_e_number = node_number + sync_agent_ranges_number - 1
for p in zmq_ports:
- assert sum(map(lambda x: x[2] == p, rules)) == zmq_e_number, p
- assert sum(map(lambda x: x.port == p, c.rules)) == zmq_e_number, p
+ assert sum(map(lambda x: x.port == p, rules)) == zmq_e_number, p
# Check sync ip ranges rules
for r in sync_agent_ranges:
assert sum(
- map(lambda x: x[0] == r.start_ip, rules)
- ) == schain_number * 2, ip
- assert sum(
- map(lambda x: x.first_ip == r.start_ip, c.rules)
- ) == schain_number * 2, ip
- assert sum(
- map(lambda x: x[1] == r.end_ip, rules)
+ map(lambda x: x.first_ip == r.start_ip, rules)
) == schain_number * 2, ip
assert sum(
- map(lambda x: x.last_ip == r.end_ip, c.rules)
+ map(lambda x: x.last_ip == r.end_ip, rules)
) == schain_number * 2, ip
for port in public_ports:
- assert sum(map(lambda x: x[2] == port, rules)) == 1, port
- assert sum(map(lambda x: x.port == port, c.rules)) == 1, port
+ assert sum(map(lambda x: x.port == port, rules)) == 1, port
@pytest.mark.parametrize('attempt', range(5))
-def test_concurrent_rc_behavior_no_refresh(attempt):
+def test_concurrent_rc_behavior_no_refresh(attempt, nft_chain_folder):
node_number = 16
schain_number = 8
own_ip = '1.1.1.1'
@@ -215,7 +199,7 @@ def test_concurrent_rc_behavior_no_refresh(attempt):
@pytest.mark.parametrize('attempt', range(5))
-def test_concurrent_rc_behavior_with_refresh(attempt, refresh):
+def test_concurrent_rc_behavior_with_refresh(attempt, refresh, nft_chain_folder):
node_number = 16
schain_number = 8
own_ip = '1.1.1.1'
diff --git a/tests/firewall/firewall_manager_test.py b/tests/firewall/firewall_manager_test.py
index 719ad1bf..04203acc 100644
--- a/tests/firewall/firewall_manager_test.py
+++ b/tests/firewall/firewall_manager_test.py
@@ -53,7 +53,7 @@ def test_firewall_manager_update_existed():
assert fm.host_controller.remove_rule.call_count == 0
-def test_firewall_manager_flush():
+def test_firewall_manager_cleanup():
fm = SChainTestFirewallManager('test', 10000, 10064)
rules = [
SChainRule(10000, '2.2.2.2'),
@@ -63,6 +63,6 @@ def test_firewall_manager_flush():
fm.add_rules(rules)
fm.host_controller.add_rule(SChainRule(10072, '2.2.2.2'))
- fm.flush()
+ fm.cleanup()
assert list(fm.rules) == []
assert fm.host_controller.has_rule(SChainRule(10072, '2.2.2.2'))
diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py
new file mode 100644
index 00000000..e77af1fa
--- /dev/null
+++ b/tests/firewall/nftables_test.py
@@ -0,0 +1,161 @@
+import concurrent.futures
+import importlib
+import os
+import time
+
+import pytest
+
+from core.schains.firewall.nftables import NFTablesController, NFT_CHAIN_BASE_PATH
+from core.schains.firewall.types import SChainRule
+from core.schains.firewall.utils import cleanup_firewall_for_schain
+from tools.helper import run_cmd
+
+
+@pytest.fixture
+def nf_test_tables():
+ nft = importlib.import_module('nftables').Nftables()
+ nft.cmd('flush ruleset')
+ return nft
+
+
+@pytest.fixture
+def filter_table(nf_test_tables):
+ print(nf_test_tables.cmd('add table inet firewall'))
+
+
+@pytest.fixture
+def custom_chain(nf_test_tables, filter_table):
+ name = 'test-chain'
+ nf_test_tables.cmd(f'add chain inet firewall skale-{name}')
+ return name
+
+
+def test_nftables_controller(custom_chain):
+ nft_controller = NFTablesController(chain='test-chain')
+ rule_a = SChainRule(10000, '1.1.1.1', '2.2.2.2')
+ rule_b = SChainRule(10001, '3.3.3.3')
+ nft_controller.add_rule(rule_a)
+ nft_controller.add_rule(rule_b)
+ assert nft_controller.has_rule(rule_a)
+ assert nft_controller.has_rule(rule_b)
+ rules = list(nft_controller.rules)
+ assert sorted(rules) == sorted([rule_b, rule_a]), (rules, sorted([rule_b, rule_a]))
+ nft_controller.remove_rule(rule_a)
+ assert not nft_controller.has_rule(rule_a)
+ assert nft_controller.has_rule(rule_b)
+ nft_controller.remove_rule(rule_b)
+ assert not nft_controller.has_rule(rule_a)
+
+
+def test_nftables_controller_duplicates(custom_chain):
+ rule_a = SChainRule(10000, '1.1.1.1', '2.2.2.2')
+ manager = NFTablesController(chain='test-chain')
+ manager.add_rule(rule_a)
+ rule_b = SChainRule(10001, '3.3.3.3', '4.4.4.4')
+ manager.add_rule(rule_b)
+ assert sorted(list(manager.rules)) == sorted([
+ SChainRule(port=10001, first_ip='3.3.3.3', last_ip='4.4.4.4'),
+ SChainRule(port=10000, first_ip='1.1.1.1', last_ip='2.2.2.2')
+ ])
+ assert manager.has_rule(rule_b)
+ manager.add_rule(rule_b)
+ assert manager.has_rule(rule_b)
+ assert sorted(list(manager.rules)) == sorted([
+ SChainRule(port=10001, first_ip='3.3.3.3', last_ip='4.4.4.4'),
+ SChainRule(port=10000, first_ip='1.1.1.1', last_ip='2.2.2.2')
+ ])
+ manager.remove_rule(rule_b)
+ assert list(manager.rules) == [
+ SChainRule(port=10000, first_ip='1.1.1.1', last_ip='2.2.2.2')
+ ]
+
+
+def test_create_delete_chain(filter_table, nft_chain_folder):
+ chain_name = 'test-chain'
+
+ output = run_cmd(['nft', 'list', 'chains']).stdout.decode('utf-8')
+ output == 'table inet firewall {\n}\n'
+ nft_chain_path = os.path.join(NFT_CHAIN_BASE_PATH, f'skale-{chain_name}.conf')
+ assert not os.path.isfile(nft_chain_path)
+
+ manager = NFTablesController(chain=chain_name)
+ manager.create_chain(first_port=10000, last_port=10063)
+
+ chains = run_cmd(['nft', 'list', 'chains']).stdout.decode('utf-8')
+ assert chains == 'table inet firewall {\n\tchain skale-test-chain {\n\t\ttype filter hook input priority filter; policy accept;\n\t}\n}\n' # noqa
+ assert os.path.isfile(nft_chain_path)
+
+ manager.cleanup()
+ chains = run_cmd(['nft', 'list', 'chains']).stdout.decode('utf-8')
+ assert chains == 'table inet firewall {\n}\n'
+ assert os.path.isfile(nft_chain_path)
+
+ manager.remove_saved_rules()
+ assert not os.path.isfile(nft_chain_path)
+
+
+def test_saved_rules(filter_table, nft_chain_folder):
+ chain_name = 'test-chain'
+ nft_chain_path = os.path.join(NFT_CHAIN_BASE_PATH, f'skale-{chain_name}.conf')
+
+ manager = NFTablesController(chain=chain_name)
+ assert not os.path.isfile(nft_chain_path)
+ manager.create_chain(first_port=10000, last_port=10063)
+ assert os.path.isfile(nft_chain_path)
+ assert manager.get_saved_rules() == 'chain skale-test-chain {\n\ttype filter hook input priority filter; policy accept;\n\ttcp dport 10000-10063 counter drop\n}\n' # noqa
+
+ assert os.path.isfile(nft_chain_path)
+
+ manager.remove_saved_rules()
+ assert not os.path.isfile(nft_chain_path)
+
+
+def test_cleanup_firewall_for_schain(filter_table, nft_chain_folder):
+ chain_name = 'test-chain'
+ nft_chain_path = os.path.join(NFT_CHAIN_BASE_PATH, f'skale-{chain_name}.conf')
+
+ manager = NFTablesController(chain=chain_name)
+ manager.create_chain(first_port=10000, last_port=10063)
+
+ cleanup_firewall_for_schain(schain_name=chain_name)
+ chains = run_cmd(['nft', 'list', 'chains']).stdout.decode('utf-8')
+ assert chains == 'table inet firewall {\n}\n'
+ assert not os.path.isfile(nft_chain_path)
+
+
+def add_remove_rule(srule, refresh):
+ manager = NFTablesController(chain='test')
+ manager.add_rule(srule)
+ time.sleep(1)
+ if not manager.has_rule(srule):
+ return False
+ time.sleep(1)
+ manager.remove_rule(srule)
+ return True
+
+
+def generate_srules(number=5):
+ return [
+ SChainRule(
+ 10000 + 1,
+ f'{i}.{i}.{i}.{i}', f'{i + 1}.{i + 1}.{i + 1}.{i + 1}'
+ )
+ for i in range(1, number * 2, 2)
+ ]
+
+
+def test_nftables_manager_parallel(custom_chain):
+ srules = generate_srules(number=12)
+
+ futures = []
+ with concurrent.futures.ProcessPoolExecutor(max_workers=12) as executor:
+ futures = [
+ executor.submit(add_remove_rule, srule)
+ for srule in srules
+ ]
+
+ for future in concurrent.futures.as_completed(futures):
+ assert future.result
+ manager = NFTablesController(custom_chain)
+ time.sleep(10)
+ assert len(list(manager.rules)) == 0
diff --git a/tests/schains/checks_test.py b/tests/schains/checks_test.py
index f0d67f32..4ac14836 100644
--- a/tests/schains/checks_test.py
+++ b/tests/schains/checks_test.py
@@ -180,6 +180,8 @@ def test_volume_check(schain_checks, sample_false_checks, dutils):
def test_firewall_rules_check(schain_checks, rules_unsynced_checks):
schain_checks.rc.sync()
+ res = schain_checks.firewall_rules
+ print(res.data)
assert schain_checks.firewall_rules
assert not rules_unsynced_checks.firewall_rules.status
diff --git a/tests/schains/cleaner_test.py b/tests/schains/cleaner_test.py
index d16b41fd..ea45b25b 100644
--- a/tests/schains/cleaner_test.py
+++ b/tests/schains/cleaner_test.py
@@ -239,7 +239,8 @@ def test_get_schains_on_node(schain_dirs_for_monitor,
]).issubset(set(result))
-def test_remove_schain(skale, schain_db, node_config, dutils):
+@mock.patch('core.schains.cleaner.cleanup_firewall_for_schain')
+def test_remove_schain(cleanup_firewall_for_schain, skale, schain_db, node_config, dutils):
schain_name = schain_db
remove_schain(skale, node_config.id, schain_name, msg='Test remove_schain', dutils=dutils)
container_name = SCHAIN_CONTAINER_NAME_TEMPLATE.format(schain_name)
@@ -250,7 +251,9 @@ def test_remove_schain(skale, schain_db, node_config, dutils):
assert record.is_deleted is True
+@mock.patch('core.schains.cleaner.cleanup_firewall_for_schain')
def test_cleanup_schain(
+ cleanup_firewall_rules,
schain_db,
node_config,
schain_on_contracts,
diff --git a/tests/utils.py b/tests/utils.py
index dc33bf91..ac1f7f7d 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -222,11 +222,20 @@ def rules(self):
def has_rule(self, srule):
return srule in self._rules
+ def save_rules(self):
+ pass
+
+ def cleanup(self):
+ pass
+
class SChainTestFirewallManager(SChainFirewallManager):
def create_host_controller(self):
return HostTestFirewallController()
+ def cleanup(self):
+ self.remove_rules(self.rules)
+
class SChainTestRuleController(SChainRuleController):
def create_firewall_manager(self):
@@ -236,6 +245,15 @@ def create_firewall_manager(self):
self.base_port + self.ports_per_schain
)
+ def is_persistent(self) -> bool:
+ return True
+
+ def is_inited(self) -> bool:
+ return True
+
+ def cleanup(self) -> None:
+ self.firewall_manager.cleanup()
+
def get_test_rule_controller(
name,
diff --git a/tools/configs/__init__.py b/tools/configs/__init__.py
index 4794de04..1e2fd423 100644
--- a/tools/configs/__init__.py
+++ b/tools/configs/__init__.py
@@ -106,3 +106,6 @@
SYNC_NODE = os.getenv('SYNC_NODE') == 'True'
DOCKER_NODE_CONFIG_FILEPATH = os.path.join(NODE_DATA_PATH, 'docker.json')
+
+NFT_CHAIN_BASE_PATH = '/etc/nft.conf.d/skale/chains'
+NFT_CHAIN_CONFIG_WILDCARD = os.path.join(NFT_CHAIN_BASE_PATH, '*')