From e1a6fc7054f487dbe6ea7ff05846a797d0a1f6ab Mon Sep 17 00:00:00 2001 From: andrew000 <11490628+andrew000@users.noreply.github.com> Date: Wed, 13 Nov 2024 19:15:09 +0200 Subject: [PATCH 01/10] Move project to `uv` Add typing Add pre-commit Reformat code Make code more secure Update workflow python_simplified.yml Update workflow python_detailed.yml Signed-off-by: andrew000 <11490628+andrew000@users.noreply.github.com> --- .github/workflows/python_detailed.yml | 45 +-- .github/workflows/python_simplified.yml | 32 +- .gitignore | 3 +- .pre-commit-config.yaml | 38 +++ Makefile | 37 +++ examples/__init__.py | 0 examples/kem.py | 51 +-- examples/rand.py | 25 +- examples/sig.py | 50 +-- oqs/__init__.py | 36 ++- oqs/oqs.py | 402 +++++++++++++++--------- oqs/rand.py | 24 +- pyproject.toml | 123 +++++++- setup.py | 7 - tests/__init__.py | 0 tests/test_kem.py | 43 +-- tests/test_sig.py | 51 +-- 17 files changed, 648 insertions(+), 319 deletions(-) create mode 100644 .pre-commit-config.yaml create mode 100644 Makefile create mode 100644 examples/__init__.py delete mode 100644 setup.py create mode 100644 tests/__init__.py diff --git a/.github/workflows/python_detailed.yml b/.github/workflows/python_detailed.yml index 418faae..bb78951 100644 --- a/.github/workflows/python_detailed.yml +++ b/.github/workflows/python_detailed.yml @@ -2,11 +2,11 @@ name: GitHub actions detailed on: push: - branches: ["**"] + branches: [ "**" ] pull_request: - branches: ["**"] + branches: [ "**" ] repository_dispatch: - types: ["**"] + types: [ "**" ] permissions: contents: read @@ -20,21 +20,24 @@ jobs: build: strategy: matrix: - os: [ubuntu-latest, macos-latest, windows-latest] + os: [ ubuntu-latest, macos-latest, windows-latest ] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - - name: Set up Python 3.10 - uses: actions/setup-python@v3 + - name: Install uv + uses: astral-sh/setup-uv@v3 with: - python-version: "3.10" + version: "latest" + enable-cache: true + cache-dependency-glob: "**/pyproject.toml" + + - name: Set up Python 3.9 + run: uv python install 3.9 - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install nose2 + run: uv sync --extra dev - name: Install liboqs POSIX if: matrix.os != 'windows-latest' @@ -47,17 +50,17 @@ jobs: - name: Run examples POSIX if: matrix.os != 'windows-latest' run: | - pip install . - python examples/kem.py + uv sync --extra dev + uv run examples/kem.py echo - python examples/sig.py + uv run examples/sig.py echo - python examples/rand.py + uv run examples/rand.py - name: Run unit tests POSIX if: matrix.os != 'windows-latest' run: | - nose2 --verbose + uv run nose2 --verbose - name: Install liboqs Windows if: matrix.os == 'windows-latest' @@ -73,16 +76,16 @@ jobs: shell: cmd run: | set PATH=%PATH%;${{env.WIN_LIBOQS_INSTALL_PATH}}\bin - pip install . - python examples/kem.py + uv sync --extra dev + uv run examples/kem.py echo. - python examples/sig.py + uv run examples/sig.py echo. - python examples/rand.py + uv run examples/rand.py - name: Run unit tests Windows shell: cmd if: matrix.os == 'windows-latest' run: | set PATH=%PATH%;${{env.WIN_LIBOQS_INSTALL_PATH}}\bin - nose2 --verbose + uv run nose2 --verbose diff --git a/.github/workflows/python_simplified.yml b/.github/workflows/python_simplified.yml index 387ce6e..d41c4fe 100644 --- a/.github/workflows/python_simplified.yml +++ b/.github/workflows/python_simplified.yml @@ -2,11 +2,11 @@ name: GitHub actions simplified on: push: - branches: ["**"] + branches: [ "**" ] pull_request: - branches: ["**"] + branches: [ "**" ] repository_dispatch: - types: ["**"] + types: [ "**" ] permissions: contents: read @@ -15,25 +15,29 @@ jobs: build: strategy: matrix: - os: [ubuntu-latest, macos-latest, windows-latest] + os: [ ubuntu-latest, macos-latest, windows-latest ] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - - name: Set up Python 3.10 - uses: actions/setup-python@v3 + - name: Install uv + uses: astral-sh/setup-uv@v3 with: - python-version: "3.10" + version: "latest" + enable-cache: true + cache-dependency-glob: "**/pyproject.toml" + + - name: Set up Python 3.9 + run: uv python install 3.9 - name: Run examples run: | - python -m pip install --upgrade pip - pip install . - python examples/kem.py - python examples/sig.py - python examples/rand.py + uv sync --extra dev + uv run examples/kem.py + uv run examples/sig.py + uv run examples/rand.py - name: Run unit tests run: | - nose2 --verbose + uv run nose2 --verbose diff --git a/.gitignore b/.gitignore index 6894774..f0096dd 100644 --- a/.gitignore +++ b/.gitignore @@ -117,4 +117,5 @@ pip-selfcheck.json pyvenv.cfg # vim -*.swp \ No newline at end of file +*.swp +/uv.lock diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..6c6f2db --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,38 @@ +fail_fast: false +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: "trailing-whitespace" + - id: "check-case-conflict" + - id: "check-merge-conflict" + - id: "debug-statements" + - id: "end-of-file-fixer" + - id: "mixed-line-ending" + args: [ "--fix", "crlf" ] + types: + - python + - yaml + - toml + - text + - id: "detect-private-key" + - id: "check-yaml" + - id: "check-toml" + - id: "check-json" + + - repo: https://github.com/charliermarsh/ruff-pre-commit + rev: v0.7.3 + hooks: + - id: ruff + args: [ "--fix" ] + files: "oqs" + + - id: ruff-format + files: "oqs" + + - repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort (python) + files: "oqs" diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..8a2409f --- /dev/null +++ b/Makefile @@ -0,0 +1,37 @@ +src-dir = oqs +tests-dir = tests +examples-dir = examples + +.PHONY pull: +pull: + git pull origin master + git submodule update --init --recursive + +.PHONY lint: +lint: + echo "Running ruff..." + uv run ruff check --config pyproject.toml --diff $(src-dir) $(tests-dir) $(examples-dir) + +.PHONY format: +format: + echo "Running ruff check with --fix..." + uv run ruff check --config pyproject.toml --fix --unsafe-fixes $(src-dir) $(tests-dir) $(examples-dir) + + echo "Running ruff..." + uv run ruff format --config pyproject.toml $(src-dir) $(tests-dir) $(examples-dir) + + echo "Running isort..." + uv run isort --settings-file pyproject.toml $(src-dir) $(tests-dir) $(examples-dir) + +.PHONE mypy: +mypy: + echo "Running MyPy..." + uv run mypy --config-file pyproject.toml + +.PHONY outdated: +outdated: + uv tree --outdated --universal + +.PHONY sync: +sync: + uv sync --extra dev --extra lint diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/kem.py b/examples/kem.py index a9deb2d..4a5c826 100644 --- a/examples/kem.py +++ b/examples/kem.py @@ -1,36 +1,39 @@ # Key encapsulation Python example +import logging +from pprint import pformat import oqs -from pprint import pprint -print("liboqs version:", oqs.oqs_version()) -print("liboqs-python version:", oqs.oqs_python_version()) -print("Enabled KEM mechanisms:") -kems = oqs.get_enabled_kem_mechanisms() -pprint(kems, compact=True) +logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO) +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +logger.info("liboqs version: %s", oqs.oqs_version()) +logger.info("liboqs-python version: %s", oqs.oqs_python_version()) +logger.info("Enabled KEM mechanisms: %s", pformat(oqs.get_enabled_kem_mechanisms(), compact=True)) # Create client and server with sample KEM mechanisms kemalg = "Kyber512" -with oqs.KeyEncapsulation(kemalg) as client: - with oqs.KeyEncapsulation(kemalg) as server: - print("\nKey encapsulation details:") - pprint(client.details) +with oqs.KeyEncapsulation(kemalg) as client, oqs.KeyEncapsulation(kemalg) as server: + # print("\nKey encapsulation details:") + logger.info("Client details: %s", pformat(client.details)) - # Client generates its keypair - public_key_client = client.generate_keypair() - # Optionally, the secret key can be obtained by calling export_secret_key() - # and the client can later be re-instantiated with the key pair: - # secret_key_client = client.export_secret_key() + # Client generates its keypair + public_key_client = client.generate_keypair() + # Optionally, the secret key can be obtained by calling export_secret_key() + # and the client can later be re-instantiated with the key pair: + # secret_key_client = client.export_secret_key() - # Store key pair, wait... (session resumption): - # client = oqs.KeyEncapsulation(kemalg, secret_key_client) + # Store key pair, wait... (session resumption): + # client = oqs.KeyEncapsulation(kemalg, secret_key_client) - # The server encapsulates its secret using the client's public key - ciphertext, shared_secret_server = server.encap_secret(public_key_client) + # The server encapsulates its secret using the client's public key + ciphertext, shared_secret_server = server.encap_secret(public_key_client) - # The client decapsulates the server's ciphertext to obtain the shared secret - shared_secret_client = client.decap_secret(ciphertext) + # The client decapsulates the server's ciphertext to obtain the shared secret + shared_secret_client = client.decap_secret(ciphertext) - print( - "\nShared secretes coincide:", shared_secret_client == shared_secret_server - ) + logger.info( + "Shared secretes coincide: %s", + shared_secret_client == shared_secret_server, + ) diff --git a/examples/rand.py b/examples/rand.py index 75e5c89..5af394f 100644 --- a/examples/rand.py +++ b/examples/rand.py @@ -1,22 +1,27 @@ # Various RNGs Python example - +import logging import platform # to learn the OS we're on + import oqs.rand as oqsrand # must be explicitly imported -from oqs import oqs_version, oqs_python_version +from oqs import oqs_python_version, oqs_version + +logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO) +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) -print("liboqs version:", oqs_version()) -print("liboqs-python version:", oqs_python_version()) +logger.info("liboqs version: %s", oqs_version()) +logger.info("liboqs-python version: %s", oqs_python_version()) oqsrand.randombytes_switch_algorithm("system") -print( - "{:17s}".format("System (default):"), - " ".join("{:02X}".format(x) for x in oqsrand.randombytes(32)), +logger.info( + "System (default): %s", + " ".join(f"{x:02X}" for x in oqsrand.randombytes(32)), ) # We do not yet support OpenSSL under Windows if platform.system() != "Windows": oqsrand.randombytes_switch_algorithm("OpenSSL") - print( - "{:17s}".format("OpenSSL:"), - " ".join("{:02X}".format(x) for x in oqsrand.randombytes(32)), + logger.info( + "OpenSSL: %s", + " ".join(f"{x:02X}" for x in oqsrand.randombytes(32)), ) diff --git a/examples/sig.py b/examples/sig.py index ea07508..1d84297 100644 --- a/examples/sig.py +++ b/examples/sig.py @@ -1,36 +1,40 @@ # Signature Python example +import logging +from pprint import pformat import oqs -from pprint import pprint -print("liboqs version:", oqs.oqs_version()) -print("liboqs-python version:", oqs.oqs_python_version()) -print("Enabled signature mechanisms:") -sigs = oqs.get_enabled_sig_mechanisms() -pprint(sigs, compact=True) +logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO) +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) -message = "This is the message to sign".encode() +logger.info("liboqs version: %s", oqs.oqs_version()) +logger.info("liboqs-python version: %s", oqs.oqs_python_version()) +logger.info( + "Enabled signature mechanisms: %s", + pformat(oqs.get_enabled_sig_mechanisms(), compact=True), +) + +message = b"This is the message to sign" # Create signer and verifier with sample signature mechanisms sigalg = "Dilithium2" -with oqs.Signature(sigalg) as signer: - with oqs.Signature(sigalg) as verifier: - print("\nSignature details:") - pprint(signer.details) +with oqs.Signature(sigalg) as signer, oqs.Signature(sigalg) as verifier: + logger.info("Signature details: %s", pformat(signer.details)) - # Signer generates its keypair - signer_public_key = signer.generate_keypair() - # Optionally, the secret key can be obtained by calling export_secret_key() - # and the signer can later be re-instantiated with the key pair: - # secret_key = signer.export_secret_key() + # Signer generates its keypair + signer_public_key = signer.generate_keypair() + # Optionally, the secret key can be obtained by calling export_secret_key() + # and the signer can later be re-instantiated with the key pair: + # secret_key = signer.export_secret_key() - # Store key pair, wait... (session resumption): - # signer = oqs.Signature(sigalg, secret_key) + # Store key pair, wait... (session resumption): + # signer = oqs.Signature(sigalg, secret_key) - # Signer signs the message - signature = signer.sign(message) + # Signer signs the message + signature = signer.sign(message) - # Verifier verifies the signature - is_valid = verifier.verify(message, signature, signer_public_key) + # Verifier verifies the signature + is_valid = verifier.verify(message, signature, signer_public_key) - print("\nValid signature?", is_valid) + logger.info("Valid signature? %s", is_valid) diff --git a/oqs/__init__.py b/oqs/__init__.py index f82845b..6dab0c3 100644 --- a/oqs/__init__.py +++ b/oqs/__init__.py @@ -1 +1,35 @@ -from oqs.oqs import * +from oqs.oqs import ( + OQS_SUCCESS, + OQS_VERSION, + KeyEncapsulation, + MechanismNotEnabledError, + MechanismNotSupportedError, + Signature, + get_enabled_kem_mechanisms, + get_enabled_sig_mechanisms, + get_supported_kem_mechanisms, + get_supported_sig_mechanisms, + is_kem_enabled, + is_sig_enabled, + native, + oqs_python_version, + oqs_version, +) + +__all__ = ( + "KeyEncapsulation", + "MechanismNotEnabledError", + "MechanismNotSupportedError", + "OQS_SUCCESS", + "OQS_VERSION", + "Signature", + "get_enabled_kem_mechanisms", + "get_enabled_sig_mechanisms", + "get_supported_kem_mechanisms", + "get_supported_sig_mechanisms", + "is_kem_enabled", + "is_sig_enabled", + "native", + "oqs_python_version", + "oqs_version", +) diff --git a/oqs/oqs.py b/oqs/oqs.py index fc147da..a244ed6 100644 --- a/oqs/oqs.py +++ b/oqs/oqs.py @@ -1,5 +1,5 @@ """ -Open Quantum Safe (OQS) Python wrapper for liboqs +Open Quantum Safe (OQS) Python wrapper for liboqs. The liboqs project provides post-quantum public key cryptography algorithms: https://github.com/open-quantum-safe/liboqs @@ -7,23 +7,37 @@ This module provides a Python 3 interface to liboqs. """ +from __future__ import annotations + import ctypes as ct # to call native import ctypes.util as ctu import importlib.metadata # to determine module version at runtime -import os # to run OS commands (install liboqs on demand if not found) +import logging import platform # to learn the OS we're on +import subprocess import sys import tempfile # to install liboqs on demand import time import warnings +from pathlib import Path +from typing import TYPE_CHECKING, Any, ClassVar, Final, TypeVar, cast + +if TYPE_CHECKING: + from collections.abc import Sequence + from types import TracebackType + +TKeyEncapsulation = TypeVar("TKeyEncapsulation", bound="KeyEncapsulation") +TSignature = TypeVar("TSignature", bound="Signature") +logger = logging.getLogger(__name__) -def oqs_python_version(): + +def oqs_python_version() -> str | None: """liboqs-python version string.""" try: result = importlib.metadata.version("liboqs-python") except importlib.metadata.PackageNotFoundError: - warnings.warn("Please install liboqs-python using pip install") + warnings.warn("Please install liboqs-python using pip install", stacklevel=2) return None return result @@ -33,115 +47,155 @@ def oqs_python_version(): OQS_VERSION = oqs_python_version() -def _countdown(seconds): +def _countdown(seconds: int) -> None: while seconds > 0: - print(seconds, end=" ") + logger.info("Installing in %s seconds...", seconds) sys.stdout.flush() seconds -= 1 time.sleep(1) - print() -def _load_shared_obj(name, additional_searching_paths=None): - """Attempts to load shared library.""" - paths = [] +def _load_shared_obj( + name: str, + additional_searching_paths: Sequence[Path] | None = None, +) -> ct.CDLL: + """Attempt to load shared library.""" + paths: list[Path] = [] dll = ct.windll if platform.system() == "Windows" else ct.cdll # Search additional path, if any if additional_searching_paths: for path in additional_searching_paths: if platform.system() == "Darwin": - paths.append( - os.path.abspath(path) + os.path.sep + "lib" + name + ".dylib" - ) + paths.append(path.absolute() / Path(f"lib{name}").with_suffix(".dylib")) elif platform.system() == "Windows": - paths.append(os.path.abspath(path) + os.path.sep + name + ".dll") + paths.append(path.absolute() / Path(name).with_suffix(".dll")) # Does not work # os.environ["PATH"] += os.path.abspath(path) else: # Linux/FreeBSD/UNIX - paths.append(os.path.abspath(path) + os.path.sep + "lib" + name + ".so") + paths.append(path.absolute() / Path(f"lib{name}").with_suffix(".so")) # https://stackoverflow.com/questions/856116/changing-ld-library-path-at-runtime-for-ctypes # os.environ["LD_LIBRARY_PATH"] += os.path.abspath(path) # Search typical locations - try: - paths.insert(0, ctu.find_library(name)) - except FileNotFoundError: - pass - try: - paths.insert(0, ctu.find_library("lib" + name)) - except FileNotFoundError: - pass + + if found_lib := ctu.find_library(name): + paths.insert(0, Path(found_lib)) + + if found_lib := ctu.find_library("lib" + name): + paths.insert(0, Path(found_lib)) for path in paths: if path: try: - lib = dll.LoadLibrary(path) - return lib + lib: ct.CDLL = dll.LoadLibrary(str(path)) except OSError: pass - raise RuntimeError("No " + name + " shared libraries found") + else: + return lib + + msg = f"No {name} shared libraries found" + raise RuntimeError(msg) -def _install_liboqs(target_directory, oqs_version=None): - """Install liboqs version oqs_version (if None, installs latest at HEAD) in the target_directory.""" +def _install_liboqs(target_directory: Path, oqs_version_to_install: str | None = None) -> None: + """Install liboqs version oqs_version (if None, installs latest at HEAD) in the target_directory.""" # noqa: E501 with tempfile.TemporaryDirectory() as tmpdirname: - oqs_install_str = ( - "cd " - + tmpdirname - + " && git clone https://github.com/open-quantum-safe/liboqs" - ) - if oqs_version: - oqs_install_str += " --branch " + oqs_version - oqs_install_str += ( - " --depth 1 && cmake -S liboqs -B liboqs/build -DBUILD_SHARED_LIBS=ON -DOQS_BUILD_ONLY_LIB=ON -DCMAKE_INSTALL_PREFIX=" - + target_directory + oqs_install_cmd = [ + "cd", + tmpdirname, + "&&", + "git", + "clone", + "https://github.com/open-quantum-safe/liboqs", + ] + if oqs_version_to_install: + oqs_install_cmd.extend(["--branch", oqs_version_to_install]) + + oqs_install_cmd.extend( + [ + "--depth", + "1", + "&&", + "cmake", + "-S", + "liboqs", + "-B", + "liboqs/build", + "-DBUILD_SHARED_LIBS=ON", + "-DOQS_BUILD_ONLY_LIB=ON", + f"-DCMAKE_INSTALL_PREFIX={target_directory}", + ], ) + if platform.system() == "Windows": - oqs_install_str += " -DCMAKE_WINDOWS_EXPORT_ALL_SYMBOLS=TRUE" - oqs_install_str += " && cmake --build liboqs/build --parallel 4 && cmake --build liboqs/build --target install" - print("liboqs not found, installing it in " + target_directory) + oqs_install_cmd.append("-DCMAKE_WINDOWS_EXPORT_ALL_SYMBOLS=TRUE") + + oqs_install_cmd.extend( + [ + "&&", + "cmake", + "--build", + "liboqs/build", + "--parallel", + "4", + "&&", + "cmake", + "--build", + "liboqs/build", + "--target", + "install", + ], + ) + logger.info("liboqs not found, installing it in %s", str(target_directory)) _countdown(5) - os.system(oqs_install_str) - print("Done installing liboqs") + _retcode = subprocess.call(" ".join(oqs_install_cmd), shell=True) # noqa: S602 -def _load_liboqs(): - home_dir = os.path.expanduser("~") - oqs_install_dir = os.path.abspath(home_dir + os.path.sep + "_oqs") # $HOME/_oqs + if _retcode != 0: + logger.exception("Error installing liboqs.") + sys.exit(1) + + logger.info("Done installing liboqs") + + +def _load_liboqs() -> ct.CDLL: + home_dir = Path.home() + oqs_install_dir = home_dir / "_oqs" oqs_lib_dir = ( - os.path.abspath(oqs_install_dir + os.path.sep + "bin") # $HOME/_oqs/bin + oqs_install_dir / "bin" # $HOME/_oqs/bin if platform.system() == "Windows" - else os.path.abspath(oqs_install_dir + os.path.sep + "lib") # $HOME/_oqs/lib + else oqs_install_dir / "lib" # $HOME/_oqs/lib ) try: - _liboqs = _load_shared_obj(name="oqs", additional_searching_paths=[oqs_lib_dir]) - assert _liboqs + liboqs = _load_shared_obj(name="oqs", additional_searching_paths=[oqs_lib_dir]) + assert liboqs # noqa: S101 except RuntimeError: # We don't have liboqs, so we try to install it automatically - _install_liboqs(target_directory=oqs_install_dir, oqs_version=OQS_VERSION) + _install_liboqs(target_directory=oqs_install_dir, oqs_version_to_install=OQS_VERSION) # Try loading it again try: - _liboqs = _load_shared_obj( - name="oqs", additional_searching_paths=[oqs_lib_dir] + liboqs = _load_shared_obj( + name="oqs", + additional_searching_paths=[oqs_lib_dir], ) - assert _liboqs + assert liboqs # noqa: S101 except RuntimeError: sys.exit("Could not load liboqs shared library") - return _liboqs + return liboqs _liboqs = _load_liboqs() # Expected return value from native OQS functions -OQS_SUCCESS = 0 -OQS_ERROR = -1 +OQS_SUCCESS: Final[int] = 0 +OQS_ERROR: Final[int] = -1 -def native(): +def native() -> ct.CDLL: """Handle to native liboqs handler.""" return _liboqs @@ -150,41 +204,37 @@ def native(): native().OQS_init() -def oqs_version(): - """liboqs version string.""" +def oqs_version() -> str: + """`liboqs` version string.""" native().OQS_version.restype = ct.c_char_p - return ct.c_char_p(native().OQS_version()).value.decode("UTF-8") + return ct.c_char_p(native().OQS_version()).value.decode("UTF-8") # type: ignore[union-attr] # Warn the user if the liboqs version differs from liboqs-python version if oqs_version() != oqs_python_version(): warnings.warn( - "liboqs version {} differs from liboqs-python version {}".format( - oqs_version(), oqs_python_version() - ) + f"liboqs version {oqs_version()} differs from liboqs-python version " + f"{oqs_python_version()}", + stacklevel=2, ) class MechanismNotSupportedError(Exception): """Exception raised when an algorithm is not supported by OQS.""" - def __init__(self, alg_name): - """ - :param alg_name: requested algorithm name. - """ + def __init__(self, alg_name: str) -> None: + """:param alg_name: requested algorithm name.""" self.alg_name = alg_name - self.message = alg_name + " is not supported by OQS" + self.message = f"{alg_name} is not supported by OQS" class MechanismNotEnabledError(MechanismNotSupportedError): """Exception raised when an algorithm is supported but not enabled by OQS.""" - def __init__(self, alg_name): - """ - :param alg_name: requested algorithm name. - """ + def __init__(self, alg_name: str) -> None: + """:param alg_name: requested algorithm name.""" self.alg_name = alg_name - self.message = alg_name + " is supported but not enabled by OQS" + self.message = f"{alg_name} is supported but not enabled by OQS" class KeyEncapsulation(ct.Structure): @@ -201,7 +251,7 @@ class KeyEncapsulation(ct.Structure): free | OQS_KEM_free """ - _fields_ = [ + _fields_: ClassVar[list[tuple[str, Any]]] = [ ("method_name", ct.c_char_p), ("alg_version", ct.c_char_p), ("claimed_nist_level", ct.c_ubyte), @@ -215,9 +265,9 @@ class KeyEncapsulation(ct.Structure): ("decaps_cb", ct.c_void_p), ] - def __init__(self, alg_name, secret_key=None): + def __init__(self, alg_name: str, secret_key: int | bytes | None = None) -> None: """ - Creates new KeyEncapsulation with the given algorithm. + Create new KeyEncapsulation with the given algorithm. :param alg_name: KEM mechanism algorithm name. Enabled KEM mechanisms can be obtained with get_enabled_KEM_mechanisms(). @@ -229,8 +279,7 @@ def __init__(self, alg_name, secret_key=None): # perhaps it's a supported but not enabled alg if alg_name in _supported_KEMs: raise MechanismNotEnabledError(alg_name) - else: - raise MechanismNotSupportedError(alg_name) + raise MechanismNotSupportedError(alg_name) self._kem = native().OQS_KEM_new(ct.create_string_buffer(alg_name.encode())) @@ -247,102 +296,131 @@ def __init__(self, alg_name, secret_key=None): if secret_key: self.secret_key = ct.create_string_buffer( - secret_key, self._kem.contents.length_secret_key + secret_key, + self._kem.contents.length_secret_key, ) - def __enter__(self): + def __enter__(self: TKeyEncapsulation) -> TKeyEncapsulation: return self - def __exit__(self, ctx_type, ctx_value, ctx_traceback): + def __exit__( + self, + ctx_type: type[BaseException] | None, + ctx_value: BaseException | None, + ctx_traceback: TracebackType | None, + ) -> None: self.free() - def generate_keypair(self): + def generate_keypair(self) -> bytes | int: """ - Generates a new keypair and returns the public key. + Generate a new keypair and returns the public key. If needed, the secret key can be obtained with export_secret_key(). """ public_key = ct.create_string_buffer(self._kem.contents.length_public_key) self.secret_key = ct.create_string_buffer(self._kem.contents.length_secret_key) rv = native().OQS_KEM_keypair( - self._kem, ct.byref(public_key), ct.byref(self.secret_key) + self._kem, + ct.byref(public_key), + ct.byref(self.secret_key), ) return bytes(public_key) if rv == OQS_SUCCESS else 0 - def export_secret_key(self): - """Exports the secret key.""" + def export_secret_key(self) -> bytes: + """Export the secret key.""" return bytes(self.secret_key) - def encap_secret(self, public_key): + def encap_secret(self, public_key: int | bytes) -> tuple[bytes, bytes | int]: """ - Generates and encapsulates a secret using the provided public key. + Generate and encapsulates a secret using the provided public key. :param public_key: the peer's public key. """ my_public_key = ct.create_string_buffer( - public_key, self._kem.contents.length_public_key + public_key, + self._kem.contents.length_public_key, + ) + ciphertext: ct.Array[ct.c_char] = ct.create_string_buffer( + self._kem.contents.length_ciphertext, + ) + shared_secret: ct.Array[ct.c_char] = ct.create_string_buffer( + self._kem.contents.length_shared_secret, ) - ciphertext = ct.create_string_buffer(self._kem.contents.length_ciphertext) - shared_secret = ct.create_string_buffer(self._kem.contents.length_shared_secret) rv = native().OQS_KEM_encaps( - self._kem, ct.byref(ciphertext), ct.byref(shared_secret), my_public_key + self._kem, + ct.byref(ciphertext), + ct.byref(shared_secret), + my_public_key, ) - return bytes(ciphertext), bytes(shared_secret) if rv == OQS_SUCCESS else 0 - def decap_secret(self, ciphertext): + # TODO: What should it return? + # 1. tuple[bytes | int, bytes | int] + # 2. tuple[bytes, bytes | int] + # 3. tuple[bytes, bytes] | int + return ( + bytes(cast(bytes, ciphertext)), + bytes(cast(bytes, shared_secret)) if rv == OQS_SUCCESS else 0, + ) + + def decap_secret(self, ciphertext: int | bytes) -> bytes | int: """ - Decapsulates the ciphertext and returns the secret. + Decapsulate the ciphertext and returns the secret. :param ciphertext: the ciphertext received from the peer. """ my_ciphertext = ct.create_string_buffer( - ciphertext, self._kem.contents.length_ciphertext + ciphertext, + self._kem.contents.length_ciphertext, + ) + shared_secret: ct.Array[ct.c_char] = ct.create_string_buffer( + self._kem.contents.length_shared_secret, ) - shared_secret = ct.create_string_buffer(self._kem.contents.length_shared_secret) rv = native().OQS_KEM_decaps( - self._kem, ct.byref(shared_secret), my_ciphertext, self.secret_key + self._kem, + ct.byref(shared_secret), + my_ciphertext, + self.secret_key, ) - return bytes(shared_secret) if rv == OQS_SUCCESS else 0 + return bytes(cast(bytes, shared_secret)) if rv == OQS_SUCCESS else 0 - def free(self): + def free(self) -> None: """Releases the native resources.""" if hasattr(self, "secret_key"): native().OQS_MEM_cleanse( - ct.byref(self.secret_key), self._kem.contents.length_secret_key + ct.byref(self.secret_key), + self._kem.contents.length_secret_key, ) native().OQS_KEM_free(self._kem) - def __repr__(self): - return "Key encapsulation mechanism: " + self._kem.contents.method_name.decode() + def __repr__(self) -> str: + return f"Key encapsulation mechanism: {self._kem.contents.method_name.decode()}" native().OQS_KEM_new.restype = ct.POINTER(KeyEncapsulation) native().OQS_KEM_alg_identifier.restype = ct.c_char_p -def is_kem_enabled(alg_name): +def is_kem_enabled(alg_name: str) -> bool: """ - Returns True if the KEM algorithm is enabled. + Return True if the KEM algorithm is enabled. :param alg_name: a KEM mechanism algorithm name. """ return native().OQS_KEM_alg_is_enabled(ct.create_string_buffer(alg_name.encode())) -_KEM_alg_ids = [ - native().OQS_KEM_alg_identifier(i) for i in range(native().OQS_KEM_alg_count()) -] -_supported_KEMs = [i.decode() for i in _KEM_alg_ids] -_enabled_KEMs = [i for i in _supported_KEMs if is_kem_enabled(i)] +_KEM_alg_ids = [native().OQS_KEM_alg_identifier(i) for i in range(native().OQS_KEM_alg_count())] +_supported_KEMs: list[str] = [i.decode() for i in _KEM_alg_ids] # noqa: N816 +_enabled_KEMs: list[str] = [i for i in _supported_KEMs if is_kem_enabled(i)] # noqa: N816 -def get_enabled_kem_mechanisms(): - """Returns the list of enabled KEM mechanisms.""" +def get_enabled_kem_mechanisms() -> list[str]: + """Return the list of enabled KEM mechanisms.""" return _enabled_KEMs -def get_supported_kem_mechanisms(): - """Returns the list of supported KEM mechanisms.""" +def get_supported_kem_mechanisms() -> list[str]: + """Return the list of supported KEM mechanisms.""" return _supported_KEMs @@ -360,7 +438,7 @@ class Signature(ct.Structure): free | OQS_SIG_free """ - _fields_ = [ + _fields_: ClassVar[list[tuple[str, Any]]] = [ ("method_name", ct.c_char_p), ("alg_version", ct.c_char_p), ("claimed_nist_level", ct.c_ubyte), @@ -373,12 +451,12 @@ class Signature(ct.Structure): ("verify_cb", ct.c_void_p), ] - def __init__(self, alg_name, secret_key=None): + def __init__(self, alg_name: str, secret_key: int | bytes | None = None) -> None: """ - Creates new Signature with the given algorithm. + Create new Signature with the given algorithm. - :param alg_name: a signature mechanism algorithm name. Enabled signature mechanisms can be obtained with - get_enabled_sig_mechanisms(). + :param alg_name: a signature mechanism algorithm name. Enabled signature mechanisms can be + obtained with get_enabled_sig_mechanisms(). :param secret_key: optional, if generated by generate_keypair(). """ super().__init__() @@ -386,8 +464,7 @@ def __init__(self, alg_name, secret_key=None): # perhaps it's a supported but not enabled alg if alg_name in _supported_sigs: raise MechanismNotEnabledError(alg_name) - else: - raise MechanismNotSupportedError(alg_name) + raise MechanismNotSupportedError(alg_name) self._sig = native().OQS_SIG_new(ct.create_string_buffer(alg_name.encode())) self.details = { @@ -402,33 +479,43 @@ def __init__(self, alg_name, secret_key=None): if secret_key: self.secret_key = ct.create_string_buffer( - secret_key, self._sig.contents.length_secret_key + secret_key, + self._sig.contents.length_secret_key, ) - def __enter__(self): + def __enter__(self: TSignature) -> TSignature: return self - def __exit__(self, ctx_type, ctx_value, ctx_traceback): + def __exit__( + self, + ctx_type: type[BaseException] | None, + ctx_value: BaseException | None, + ctx_traceback: TracebackType | None, + ) -> None: self.free() - def generate_keypair(self): + def generate_keypair(self) -> bytes | int: """ - Generates a new keypair and returns the public key. + Generate a new keypair and returns the public key. If needed, the secret key can be obtained with export_secret_key(). """ - public_key = ct.create_string_buffer(self._sig.contents.length_public_key) + public_key: ct.Array[ct.c_char] = ct.create_string_buffer( + self._sig.contents.length_public_key, + ) self.secret_key = ct.create_string_buffer(self._sig.contents.length_secret_key) rv = native().OQS_SIG_keypair( - self._sig, ct.byref(public_key), ct.byref(self.secret_key) + self._sig, + ct.byref(public_key), + ct.byref(self.secret_key), ) - return bytes(public_key) if rv == OQS_SUCCESS else 0 + return bytes(cast(bytes, public_key)) if rv == OQS_SUCCESS else 0 - def export_secret_key(self): - """Exports the secret key.""" + def export_secret_key(self) -> bytes: + """Export the secret key.""" return bytes(self.secret_key) - def sign(self, message): + def sign(self, message: bytes) -> bytes | int: """ Signs the provided message and returns the signature. @@ -437,9 +524,11 @@ def sign(self, message): # Provide length to avoid extra null char my_message = ct.create_string_buffer(message, len(message)) message_len = ct.c_int(len(my_message)) - signature = ct.create_string_buffer(self._sig.contents.length_signature) + signature: ct.Array[ct.c_char] = ct.create_string_buffer( + self._sig.contents.length_signature, + ) sig_len = ct.c_int( - self._sig.contents.length_signature + self._sig.contents.length_signature, ) # initialize to maximum signature size rv = native().OQS_SIG_sign( self._sig, @@ -450,11 +539,11 @@ def sign(self, message): self.secret_key, ) - return bytes(signature[: sig_len.value]) if rv == OQS_SUCCESS else 0 + return bytes(cast(bytes, signature[: sig_len.value])) if rv == OQS_SUCCESS else 0 - def verify(self, message, signature, public_key): + def verify(self, message: bytes, signature: bytes, public_key: bytes) -> bool: """ - Verifies the provided signature on the message; returns True if valid. + Verify the provided signature on the message; returns True if valid. :param message: the signed message. :param signature: the signature on the message. @@ -468,50 +557,55 @@ def verify(self, message, signature, public_key): my_signature = ct.create_string_buffer(signature, len(signature)) sig_len = ct.c_int(len(my_signature)) my_public_key = ct.create_string_buffer( - public_key, self._sig.contents.length_public_key + public_key, + self._sig.contents.length_public_key, ) rv = native().OQS_SIG_verify( - self._sig, my_message, message_len, my_signature, sig_len, my_public_key + self._sig, + my_message, + message_len, + my_signature, + sig_len, + my_public_key, ) - return True if rv == OQS_SUCCESS else False + return rv == OQS_SUCCESS - def free(self): + def free(self) -> None: """Releases the native resources.""" if hasattr(self, "secret_key"): native().OQS_MEM_cleanse( - ct.byref(self.secret_key), self._sig.contents.length_secret_key + ct.byref(self.secret_key), + self._sig.contents.length_secret_key, ) native().OQS_SIG_free(self._sig) - def __repr__(self): - return "Signature mechanism: " + self._sig.contents.method_name.decode() + def __repr__(self) -> str: + return f"Signature mechanism: {self._sig.contents.method_name.decode()}" native().OQS_SIG_new.restype = ct.POINTER(Signature) native().OQS_SIG_alg_identifier.restype = ct.c_char_p -def is_sig_enabled(alg_name): +def is_sig_enabled(alg_name: str) -> bool: """ - Returns True if the signature algorithm is enabled. + Return True if the signature algorithm is enabled. :param alg_name: a signature mechanism algorithm name. """ return native().OQS_SIG_alg_is_enabled(ct.create_string_buffer(alg_name.encode())) -_sig_alg_ids = [ - native().OQS_SIG_alg_identifier(i) for i in range(native().OQS_SIG_alg_count()) -] +_sig_alg_ids = [native().OQS_SIG_alg_identifier(i) for i in range(native().OQS_SIG_alg_count())] _supported_sigs = [i.decode() for i in _sig_alg_ids] _enabled_sigs = [i for i in _supported_sigs if is_sig_enabled(i)] -def get_enabled_sig_mechanisms(): - """Returns the list of enabled signature mechanisms.""" +def get_enabled_sig_mechanisms() -> list[str]: + """Return the list of enabled signature mechanisms.""" return _enabled_sigs -def get_supported_sig_mechanisms(): - """Returns the list of supported signature mechanisms.""" +def get_supported_sig_mechanisms() -> list[str]: + """Return the list of supported signature mechanisms.""" return _supported_sigs diff --git a/oqs/rand.py b/oqs/rand.py index 5e7fd77..c64f50e 100644 --- a/oqs/rand.py +++ b/oqs/rand.py @@ -1,5 +1,5 @@ """ -Open Quantum Safe (OQS) Python Wrapper for liboqs +Open Quantum Safe (OQS) Python Wrapper for liboqs. The liboqs project provides post-quantum public key cryptography algorithms: https://github.com/open-quantum-safe/liboqs @@ -7,32 +7,36 @@ This module provides a Python 3 interface to libOQS RNGs. """ +import ctypes as ct + import oqs -def randombytes(bytes_to_read): +def randombytes(bytes_to_read: int) -> bytes: """ - Generates random bytes. This implementation uses either the default RNG algorithm ("system"), or whichever - algorithm has been selected by random_bytes_switch_algorithm(). + Generate random bytes. This implementation uses either the default RNG algorithm ("system"), + or whichever algorithm has been selected by random_bytes_switch_algorithm(). :param bytes_to_read: the number of random bytes to generate. :return: random bytes. """ - result = oqs.ct.create_string_buffer(bytes_to_read) - oqs.native().OQS_randombytes(result, oqs.ct.c_int(bytes_to_read)) + result = ct.create_string_buffer(bytes_to_read) + oqs.native().OQS_randombytes(result, ct.c_int(bytes_to_read)) return bytes(result) -def randombytes_switch_algorithm(alg_name): +def randombytes_switch_algorithm(alg_name: str) -> None: """ - Switches the core OQS_randombytes to use the specified algorithm. See liboqs headers for more details. + Switches the core OQS_randombytes to use the specified algorithm. See liboqs + headers for more details. :param alg_name: algorithm name, possible values are "system" and "OpenSSL". """ if ( oqs.native().OQS_randombytes_switch_algorithm( - oqs.ct.create_string_buffer(alg_name.encode()) + ct.create_string_buffer(alg_name.encode()), ) != oqs.OQS_SUCCESS ): - raise RuntimeError("Can not switch algorithm") + msg = "Can not switch algorithm" + raise RuntimeError(msg) diff --git a/pyproject.toml b/pyproject.toml index fd558c0..c5296ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,6 @@ -[build-system] -requires = [ - "setuptools>=42", - "wheel", -] -build-backend = "setuptools.build_meta" - [project] name = "liboqs-python" -requires-python = ">=3.8" +requires-python = ">=3.9" version = "0.10.0" description = "Python bindings for liboqs, providing post-quantum public key cryptography algorithms" authors = [ @@ -15,13 +8,119 @@ authors = [ ] readme = "README.md" license = { file = "LICENSE" } -dependencies = [ - "nose2", +dependencies = [] + +[tool.uv] +package = true + +[project.optional-dependencies] +dev = [ + "isort==5.13.2", + "pre-commit==4.0.1", + "ruff==0.7.3", + "bandit==1.7.10", + "nose2==0.15.1", ] +lint = [ + "mypy==1.13.0", + "types-pytz==2024.2.0.20241003", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["oqs"] [project.urls] homepage = "https://github.com/open-quantum-safe/liboqs-python" repository = "https://github.com/open-quantum-safe/liboqs-python.git" -[tool.setuptools] -py-modules = [] +[tool.isort] +py_version = 39 +src_paths = ["oqs"] +line_length = 99 +multi_line_output = 3 +force_grid_wrap = 0 +include_trailing_comma = true +split_on_trailing_comma = false +single_line_exclusions = ["."] +sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] +skip_gitignore = true +extend_skip = ["__pycache__"] +extend_skip_glob = [] + +[tool.ruff] +src = ["oqs"] +target-version = "py39" +line-length = 99 +exclude = [ + ".git", + ".mypy_cache", + ".ruff_cache", + "__pypackages__", + "__pycache__", + "*.pyi", + "venv", + ".venv", +] + +[tool.ruff.lint] +select = ["ALL"] +ignore = [ + "A003", + "ANN002", "ANN003", "ANN101", "ANN102", "ANN401", + "C901", + "D100", "D101", "D102", "D103", "D104", "D105", "D106", "D107", "D203", "D205", "D212", + "ERA001", + "FA100", "FA102", + "FBT001", "FBT002", + "FIX002", + "I001", + "PLR0911", "PLR0912", "PLR0913", "PLR0915", "PLR5501", + "PLW0120", + "RUF001", + "TD002", "TD003" +] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" + +[tool.mypy] +python_version = "3.9" +mypy_path = "." +packages = ["oqs"] +plugins = [] +allow_redefinition = true +check_untyped_defs = true +disallow_any_generics = true +disallow_incomplete_defs = true +disallow_untyped_calls = true +disallow_untyped_defs = true +extra_checks = true +follow_imports_for_stubs = true +ignore_missing_imports = false +namespace_packages = true +no_implicit_optional = true +no_implicit_reexport = true +pretty = true +show_absolute_path = true +show_error_codes = true +show_error_context = true +warn_redundant_casts = true +warn_unused_configs = true +warn_unused_ignores = true + +disable_error_code = [ + "no-redef", +] + +exclude = [ + "\\.?venv", + "\\.idea", + "\\.tests?", +] diff --git a/setup.py b/setup.py deleted file mode 100644 index ab322e4..0000000 --- a/setup.py +++ /dev/null @@ -1,7 +0,0 @@ -from setuptools import find_packages, setup - -setup( - packages=find_packages( - exclude=["tests", "docs", "examples"], - ), -) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_kem.py b/tests/test_kem.py index cd13640..348fb10 100644 --- a/tests/test_kem.py +++ b/tests/test_kem.py @@ -1,67 +1,72 @@ -import oqs import platform # to learn the OS we're on import random +import oqs + # KEMs for which unit testing is disabled -disabled_KEM_patterns = [] +disabled_KEM_patterns = [] # noqa: N816 if platform.system() == "Windows": - disabled_KEM_patterns = ["Classic-McEliece"] + disabled_KEM_patterns = ["Classic-McEliece"] # noqa: N816 -def test_correctness(): +def test_correctness() -> tuple[None, str]: for alg_name in oqs.get_enabled_kem_mechanisms(): if any(item in alg_name for item in disabled_KEM_patterns): continue yield check_correctness, alg_name -def check_correctness(alg_name): +def check_correctness(alg_name: str) -> None: with oqs.KeyEncapsulation(alg_name) as kem: public_key = kem.generate_keypair() ciphertext, shared_secret_server = kem.encap_secret(public_key) shared_secret_client = kem.decap_secret(ciphertext) - assert shared_secret_client == shared_secret_server + assert shared_secret_client == shared_secret_server # noqa: S101 -def test_wrong_ciphertext(): +def test_wrong_ciphertext() -> tuple[None, str]: for alg_name in oqs.get_enabled_kem_mechanisms(): if any(item in alg_name for item in disabled_KEM_patterns): continue yield check_wrong_ciphertext, alg_name -def check_wrong_ciphertext(alg_name): +def check_wrong_ciphertext(alg_name: str) -> None: with oqs.KeyEncapsulation(alg_name) as kem: public_key = kem.generate_keypair() ciphertext, shared_secret_server = kem.encap_secret(public_key) wrong_ciphertext = bytes(random.getrandbits(8) for _ in range(len(ciphertext))) shared_secret_client = kem.decap_secret(wrong_ciphertext) - assert shared_secret_client != shared_secret_server + assert shared_secret_client != shared_secret_server # noqa: S101 -def test_not_supported(): +def test_not_supported() -> None: try: - with oqs.KeyEncapsulation("bogus") as kem: - raise AssertionError("oqs.MechanismNotSupportedError was not raised.") + with oqs.KeyEncapsulation("bogus") as _kem: + msg = "oqs.MechanismNotSupportedError was not raised." + raise AssertionError(msg) # noqa: TRY301 except oqs.MechanismNotSupportedError: pass - except Exception as ex: - raise AssertionError("An unexpected exception was raised. " + ex) + except Exception as ex: # noqa: BLE001 + msg = f"An unexpected exception was raised. {ex}" + raise AssertionError(msg) # noqa: B904 -def test_not_enabled(): +def test_not_enabled() -> None: # TODO: test broken as the compiled lib determines which algorithms are supported and enabled for alg_name in oqs.get_supported_kem_mechanisms(): if alg_name not in oqs.get_enabled_kem_mechanisms(): # Found a non-enabled but supported alg try: - with oqs.KeyEncapsulation(alg_name) as kem: - raise AssertionError("oqs.MechanismNotEnabledError was not raised.") + with oqs.KeyEncapsulation(alg_name) as _kem: + msg = "oqs.MechanismNotEnabledError was not raised." + raise AssertionError(msg) # noqa: TRY301 except oqs.MechanismNotEnabledError: pass - except Exception as ex: - raise AssertionError("An unexpected exception was raised. " + ex) + except Exception as ex: # noqa: BLE001 + msg = f"An unexpected exception was raised. {ex}" + raise AssertionError(msg) # noqa: B904 if __name__ == "__main__": diff --git a/tests/test_sig.py b/tests/test_sig.py index 5053df4..0f0e5a1 100644 --- a/tests/test_sig.py +++ b/tests/test_sig.py @@ -1,7 +1,8 @@ -import oqs import platform # to learn the OS we're on import random +import oqs + # Sigs for which unit testing is disabled disabled_sig_patterns = [] @@ -9,91 +10,95 @@ disabled_sig_patterns = ["Rainbow-V"] -def test_correctness(): +def test_correctness() -> tuple[None, str]: for alg_name in oqs.get_enabled_sig_mechanisms(): if any(item in alg_name for item in disabled_sig_patterns): continue yield check_correctness, alg_name -def check_correctness(alg_name): +def check_correctness(alg_name: str) -> None: with oqs.Signature(alg_name) as sig: message = bytes(random.getrandbits(8) for _ in range(100)) public_key = sig.generate_keypair() signature = sig.sign(message) - assert sig.verify(message, signature, public_key) + assert sig.verify(message, signature, public_key) # noqa: S101 -def test_wrong_message(): +def test_wrong_message() -> tuple[None, str]: for alg_name in oqs.get_enabled_sig_mechanisms(): if any(item in alg_name for item in disabled_sig_patterns): continue yield check_wrong_message, alg_name -def check_wrong_message(alg_name): +def check_wrong_message(alg_name: str) -> None: with oqs.Signature(alg_name) as sig: message = bytes(random.getrandbits(8) for _ in range(100)) public_key = sig.generate_keypair() signature = sig.sign(message) wrong_message = bytes(random.getrandbits(8) for _ in range(len(message))) - assert not (sig.verify(wrong_message, signature, public_key)) + assert not (sig.verify(wrong_message, signature, public_key)) # noqa: S101 -def test_wrong_signature(): +def test_wrong_signature() -> tuple[None, str]: for alg_name in oqs.get_enabled_sig_mechanisms(): if any(item in alg_name for item in disabled_sig_patterns): continue yield check_wrong_signature, alg_name -def check_wrong_signature(alg_name): +def check_wrong_signature(alg_name: str) -> None: with oqs.Signature(alg_name) as sig: message = bytes(random.getrandbits(8) for _ in range(100)) public_key = sig.generate_keypair() signature = sig.sign(message) wrong_signature = bytes(random.getrandbits(8) for _ in range(len(signature))) - assert not (sig.verify(message, wrong_signature, public_key)) + assert not (sig.verify(message, wrong_signature, public_key)) # noqa: S101 -def test_wrong_public_key(): +def test_wrong_public_key() -> tuple[None, str]: for alg_name in oqs.get_enabled_sig_mechanisms(): if any(item in alg_name for item in disabled_sig_patterns): continue yield check_wrong_public_key, alg_name -def check_wrong_public_key(alg_name): +def check_wrong_public_key(alg_name: str) -> None: with oqs.Signature(alg_name) as sig: message = bytes(random.getrandbits(8) for _ in range(100)) public_key = sig.generate_keypair() signature = sig.sign(message) wrong_public_key = bytes(random.getrandbits(8) for _ in range(len(public_key))) - assert not (sig.verify(message, signature, wrong_public_key)) + assert not (sig.verify(message, signature, wrong_public_key)) # noqa: S101 -def test_not_supported(): +def test_not_supported() -> None: try: - with oqs.Signature("bogus") as sig: - raise AssertionError("oqs.MechanismNotSupportedError was not raised.") + with oqs.Signature("bogus") as _sig: + msg = "oqs.MechanismNotSupportedError was not raised." + raise AssertionError(msg) # noqa: TRY301 except oqs.MechanismNotSupportedError: pass - except Exception as ex: - raise AssertionError("An unexpected exception was raised. " + ex) + except Exception as ex: # noqa: BLE001 + msg = f"An unexpected exception was raised. {ex}" + raise AssertionError(msg) # noqa: B904 -def test_not_enabled(): +def test_not_enabled() -> None: # TODO: test broken as the compiled lib determines which algorithms are supported and enabled for alg_name in oqs.get_supported_sig_mechanisms(): if alg_name not in oqs.get_enabled_sig_mechanisms(): # Found a non-enabled but supported alg try: - with oqs.Signature(alg_name) as sig: - raise AssertionError("oqs.MechanismNotEnabledError was not raised.") + with oqs.Signature(alg_name) as _sig: + msg = "oqs.MechanismNotEnabledError was not raised." + raise AssertionError(msg) # noqa: TRY301 except oqs.MechanismNotEnabledError: pass - except Exception as ex: - raise AssertionError("An unexpected exception was raised. " + ex) + except Exception as ex: # noqa: BLE001 + msg = f"An unexpected exception was raised. {ex}" + raise AssertionError(msg) # noqa: B904 if __name__ == "__main__": From 1c4f1556c0542ed1b03ff667b3f19a5508dbf729 Mon Sep 17 00:00:00 2001 From: andrew000 <11490628+andrew000@users.noreply.github.com> Date: Mon, 9 Dec 2024 17:11:25 +0200 Subject: [PATCH 02/10] Replace `pipe operator` to `Union`, to support py3.9 Bump deps Signed-off-by: andrew000 <11490628+andrew000@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- oqs/oqs.py | 43 ++++++++++++++++++++++------------------- pyproject.toml | 7 ++++--- 3 files changed, 28 insertions(+), 24 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6c6f2db..439abc6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: - id: "check-json" - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.7.3 + rev: v0.8.2 hooks: - id: ruff args: [ "--fix" ] diff --git a/oqs/oqs.py b/oqs/oqs.py index a244ed6..f331a22 100644 --- a/oqs/oqs.py +++ b/oqs/oqs.py @@ -20,7 +20,7 @@ import time import warnings from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar, Final, TypeVar, cast +from typing import TYPE_CHECKING, Any, ClassVar, Final, TypeVar, Union, cast if TYPE_CHECKING: from collections.abc import Sequence @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) -def oqs_python_version() -> str | None: +def oqs_python_version() -> Union[str, None]: """liboqs-python version string.""" try: result = importlib.metadata.version("liboqs-python") @@ -57,7 +57,7 @@ def _countdown(seconds: int) -> None: def _load_shared_obj( name: str, - additional_searching_paths: Sequence[Path] | None = None, + additional_searching_paths: Union[Sequence[Path], None] = None, ) -> ct.CDLL: """Attempt to load shared library.""" paths: list[Path] = [] @@ -99,7 +99,10 @@ def _load_shared_obj( raise RuntimeError(msg) -def _install_liboqs(target_directory: Path, oqs_version_to_install: str | None = None) -> None: +def _install_liboqs( + target_directory: Path, + oqs_version_to_install: Union[str, None] = None, +) -> None: """Install liboqs version oqs_version (if None, installs latest at HEAD) in the target_directory.""" # noqa: E501 with tempfile.TemporaryDirectory() as tmpdirname: oqs_install_cmd = [ @@ -265,7 +268,7 @@ class KeyEncapsulation(ct.Structure): ("decaps_cb", ct.c_void_p), ] - def __init__(self, alg_name: str, secret_key: int | bytes | None = None) -> None: + def __init__(self, alg_name: str, secret_key: Union[int, bytes, None] = None) -> None: """ Create new KeyEncapsulation with the given algorithm. @@ -305,13 +308,13 @@ def __enter__(self: TKeyEncapsulation) -> TKeyEncapsulation: def __exit__( self, - ctx_type: type[BaseException] | None, - ctx_value: BaseException | None, - ctx_traceback: TracebackType | None, + ctx_type: Union[type[BaseException], None], + ctx_value: Union[BaseException, None], + ctx_traceback: Union[TracebackType, None], ) -> None: self.free() - def generate_keypair(self) -> bytes | int: + def generate_keypair(self) -> Union[bytes, int]: """ Generate a new keypair and returns the public key. @@ -330,7 +333,7 @@ def export_secret_key(self) -> bytes: """Export the secret key.""" return bytes(self.secret_key) - def encap_secret(self, public_key: int | bytes) -> tuple[bytes, bytes | int]: + def encap_secret(self, public_key: Union[int, bytes]) -> tuple[bytes, Union[bytes, int]]: """ Generate and encapsulates a secret using the provided public key. @@ -354,15 +357,15 @@ def encap_secret(self, public_key: int | bytes) -> tuple[bytes, bytes | int]: ) # TODO: What should it return? - # 1. tuple[bytes | int, bytes | int] - # 2. tuple[bytes, bytes | int] - # 3. tuple[bytes, bytes] | int + # 1. tuple[Union[bytes, int], Union[bytes, int]] + # 2. tuple[bytes, Union[bytes, int]] + # 3. Union[tuple[bytes, bytes], int] return ( bytes(cast(bytes, ciphertext)), bytes(cast(bytes, shared_secret)) if rv == OQS_SUCCESS else 0, ) - def decap_secret(self, ciphertext: int | bytes) -> bytes | int: + def decap_secret(self, ciphertext: Union[int, bytes]) -> Union[bytes, int]: """ Decapsulate the ciphertext and returns the secret. @@ -451,7 +454,7 @@ class Signature(ct.Structure): ("verify_cb", ct.c_void_p), ] - def __init__(self, alg_name: str, secret_key: int | bytes | None = None) -> None: + def __init__(self, alg_name: str, secret_key: Union[int, bytes, None] = None) -> None: """ Create new Signature with the given algorithm. @@ -488,13 +491,13 @@ def __enter__(self: TSignature) -> TSignature: def __exit__( self, - ctx_type: type[BaseException] | None, - ctx_value: BaseException | None, - ctx_traceback: TracebackType | None, + ctx_type: Union[type[BaseException], None], + ctx_value: Union[BaseException, None], + ctx_traceback: Union[TracebackType, None], ) -> None: self.free() - def generate_keypair(self) -> bytes | int: + def generate_keypair(self) -> Union[bytes, int]: """ Generate a new keypair and returns the public key. @@ -515,7 +518,7 @@ def export_secret_key(self) -> bytes: """Export the secret key.""" return bytes(self.secret_key) - def sign(self, message: bytes) -> bytes | int: + def sign(self, message: bytes) -> Union[bytes, int]: """ Signs the provided message and returns the signature. diff --git a/pyproject.toml b/pyproject.toml index c5296ee..634fd6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,8 +17,8 @@ package = true dev = [ "isort==5.13.2", "pre-commit==4.0.1", - "ruff==0.7.3", - "bandit==1.7.10", + "ruff==0.8.2", + "bandit==1.8.0", "nose2==0.15.1", ] lint = [ @@ -81,7 +81,8 @@ ignore = [ "PLR0911", "PLR0912", "PLR0913", "PLR0915", "PLR5501", "PLW0120", "RUF001", - "TD002", "TD003" + "TD002", "TD003", + "U007", ] [tool.ruff.format] From e9271f2dccd674259c2e59459e59984f1137c707 Mon Sep 17 00:00:00 2001 From: andrew000 <11490628+andrew000@users.noreply.github.com> Date: Sun, 26 Jan 2025 20:34:14 +0200 Subject: [PATCH 03/10] Resolve conflicts Bump deps Signed-off-by: andrew000 <11490628+andrew000@users.noreply.github.com> --- .github/workflows/python_detailed.yml | 2 +- .github/workflows/python_simplified.yml | 2 +- .pre-commit-config.yaml | 2 +- Makefile | 7 +- examples/kem.py | 3 +- examples/rand.py | 1 + examples/sig.py | 1 + oqs/__init__.py | 4 +- oqs/oqs.py | 138 +++++++++++++----------- pyproject.toml | 18 ++-- tests/test_kem.py | 51 +++++---- tests/test_sig.py | 50 +++++---- 12 files changed, 160 insertions(+), 119 deletions(-) diff --git a/.github/workflows/python_detailed.yml b/.github/workflows/python_detailed.yml index bb78951..db2b8bd 100644 --- a/.github/workflows/python_detailed.yml +++ b/.github/workflows/python_detailed.yml @@ -27,7 +27,7 @@ jobs: - uses: actions/checkout@v4 - name: Install uv - uses: astral-sh/setup-uv@v3 + uses: astral-sh/setup-uv@v5 with: version: "latest" enable-cache: true diff --git a/.github/workflows/python_simplified.yml b/.github/workflows/python_simplified.yml index d41c4fe..c1fed63 100644 --- a/.github/workflows/python_simplified.yml +++ b/.github/workflows/python_simplified.yml @@ -22,7 +22,7 @@ jobs: - uses: actions/checkout@v4 - name: Install uv - uses: astral-sh/setup-uv@v3 + uses: astral-sh/setup-uv@v5 with: version: "latest" enable-cache: true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 439abc6..f49bc59 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: - id: "check-json" - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.8.2 + rev: v0.9.3 hooks: - id: ruff args: [ "--fix" ] diff --git a/Makefile b/Makefile index 8a2409f..02b6f6a 100644 --- a/Makefile +++ b/Makefile @@ -2,11 +2,6 @@ src-dir = oqs tests-dir = tests examples-dir = examples -.PHONY pull: -pull: - git pull origin master - git submodule update --init --recursive - .PHONY lint: lint: echo "Running ruff..." @@ -26,7 +21,7 @@ format: .PHONE mypy: mypy: echo "Running MyPy..." - uv run mypy --config-file pyproject.toml + uv run mypy --config-file pyproject.toml $(src-dir) .PHONY outdated: outdated: diff --git a/examples/kem.py b/examples/kem.py index 8c27da4..f689213 100644 --- a/examples/kem.py +++ b/examples/kem.py @@ -1,4 +1,5 @@ # Key encapsulation Python example + import logging from pprint import pformat @@ -16,7 +17,7 @@ kemalg = "ML-KEM-512" with oqs.KeyEncapsulation(kemalg) as client: with oqs.KeyEncapsulation(kemalg) as server: - logger.info("Client details: %s", pformat(client.details)) + logger.info("Key encapsulation details: %s", pformat(client.details)) # Client generates its keypair public_key_client = client.generate_keypair() diff --git a/examples/rand.py b/examples/rand.py index 5af394f..118fdfd 100644 --- a/examples/rand.py +++ b/examples/rand.py @@ -1,4 +1,5 @@ # Various RNGs Python example + import logging import platform # to learn the OS we're on diff --git a/examples/sig.py b/examples/sig.py index 723aaa2..0e70fbe 100644 --- a/examples/sig.py +++ b/examples/sig.py @@ -1,4 +1,5 @@ # Signature Python example + import logging from pprint import pformat diff --git a/oqs/__init__.py b/oqs/__init__.py index 6dab0c3..f48ffe3 100644 --- a/oqs/__init__.py +++ b/oqs/__init__.py @@ -17,11 +17,11 @@ ) __all__ = ( + "OQS_SUCCESS", + "OQS_VERSION", "KeyEncapsulation", "MechanismNotEnabledError", "MechanismNotSupportedError", - "OQS_SUCCESS", - "OQS_VERSION", "Signature", "get_enabled_kem_mechanisms", "get_enabled_sig_mechanisms", diff --git a/oqs/oqs.py b/oqs/oqs.py index 0bba1b5..5363374 100644 --- a/oqs/oqs.py +++ b/oqs/oqs.py @@ -15,11 +15,12 @@ import logging import platform # to learn the OS we're on import subprocess -import sys import tempfile # to install liboqs on demand import time import warnings +from os import environ from pathlib import Path +from sys import stdout from typing import TYPE_CHECKING, Any, ClassVar, Final, TypeVar, Union, cast if TYPE_CHECKING: @@ -50,7 +51,7 @@ def oqs_python_version() -> Union[str, None]: def _countdown(seconds: int) -> None: while seconds > 0: logger.info("Installing in %s seconds...", seconds) - sys.stdout.flush() + stdout.flush() seconds -= 1 time.sleep(1) @@ -78,7 +79,6 @@ def _load_shared_obj( # os.environ["LD_LIBRARY_PATH"] += os.path.abspath(path) # Search typical locations - if found_lib := ctu.find_library(name): paths.insert(0, Path(found_lib)) @@ -91,7 +91,6 @@ def _load_shared_obj( lib: ct.CDLL = dll.LoadLibrary(str(path)) except OSError: pass - else: return lib @@ -158,14 +157,14 @@ def _install_liboqs( if _retcode != 0: logger.exception("Error installing liboqs.") - sys.exit(1) + raise SystemExit(1) logger.info("Done installing liboqs") def _load_liboqs() -> ct.CDLL: - if "OQS_INSTALL_PATH" in os.environ: - oqs_install_dir = os.path.abspath(os.environ["OQS_INSTALL_PATH"]) + if "OQS_INSTALL_PATH" in environ: + oqs_install_dir = Path(environ["OQS_INSTALL_PATH"]) else: home_dir = Path.home() oqs_install_dir = home_dir / "_oqs" @@ -175,15 +174,14 @@ def _load_liboqs() -> ct.CDLL: else oqs_install_dir / "lib" # $HOME/_oqs/lib ) oqs_lib64_dir = ( - os.path.abspath(oqs_install_dir + os.path.sep + "bin") # $HOME/_oqs/bin + oqs_install_dir / "bin" # $HOME/_oqs/bin if platform.system() == "Windows" - else os.path.abspath( - oqs_install_dir + os.path.sep + "lib64" - ) # $HOME/_oqs/lib64 + else oqs_install_dir / "lib64" # $HOME/_oqs/lib64 ) try: liboqs = _load_shared_obj( - name="oqs", additional_searching_paths=[oqs_lib_dir, oqs_lib64_dir] + name="oqs", + additional_searching_paths=[oqs_lib_dir, oqs_lib64_dir], ) assert liboqs # noqa: S101 except RuntimeError: @@ -197,7 +195,8 @@ def _load_liboqs() -> ct.CDLL: ) assert liboqs # noqa: S101 except RuntimeError: - sys.exit("Could not load liboqs shared library") + msg = "Could not load liboqs shared library" + raise SystemExit(msg) from None return liboqs @@ -329,13 +328,13 @@ def __enter__(self: TKeyEncapsulation) -> TKeyEncapsulation: def __exit__( self, - ctx_type: Union[type[BaseException], None], - ctx_value: Union[BaseException, None], - ctx_traceback: Union[TracebackType, None], + exc_type: Union[type[BaseException], None], + exc_value: Union[BaseException, None], + traceback: Union[TracebackType, None], ) -> None: self.free() - def generate_keypair(self) -> Union[bytes, int]: + def generate_keypair(self) -> bytes: """ Generate a new keypair and returns the public key. @@ -350,14 +349,14 @@ def generate_keypair(self) -> Union[bytes, int]: ) if rv == OQS_SUCCESS: return bytes(public_key) - else: - raise RuntimeError("Can not generate keypair") + msg = "Can not generate keypair" + raise RuntimeError(msg) def export_secret_key(self) -> bytes: """Export the secret key.""" return bytes(self.secret_key) - def encap_secret(self, public_key: Union[int, bytes]) -> tuple[bytes, Union[bytes, int]]: + def encap_secret(self, public_key: Union[int, bytes]) -> tuple[bytes, bytes]: """ Generate and encapsulates a secret using the provided public key. @@ -374,32 +373,39 @@ def encap_secret(self, public_key: Union[int, bytes]) -> tuple[bytes, Union[byte self._kem.contents.length_shared_secret, ) rv = native().OQS_KEM_encaps( - self._kem, ct.byref(ciphertext), ct.byref(shared_secret), c_public_key + self._kem, + ct.byref(ciphertext), + ct.byref(shared_secret), + c_public_key, ) if rv == OQS_SUCCESS: return bytes(ciphertext), bytes(shared_secret) - else: - raise RuntimeError("Can not encapsulate secret") + msg = "Can not encapsulate secret" + raise RuntimeError(msg) - def decap_secret(self, ciphertext: Union[int, bytes]) -> Union[bytes, int]: + def decap_secret(self, ciphertext: Union[int, bytes]) -> bytes: """ Decapsulate the ciphertext and returns the secret. :param ciphertext: the ciphertext received from the peer. """ c_ciphertext = ct.create_string_buffer( - ciphertext, self._kem.contents.length_ciphertext + ciphertext, + self._kem.contents.length_ciphertext, ) shared_secret: ct.Array[ct.c_char] = ct.create_string_buffer( self._kem.contents.length_shared_secret, ) rv = native().OQS_KEM_decaps( - self._kem, ct.byref(shared_secret), c_ciphertext, self.secret_key + self._kem, + ct.byref(shared_secret), + c_ciphertext, + self.secret_key, ) if rv == OQS_SUCCESS: return bytes(shared_secret) - else: - raise RuntimeError("Can not decapsulate secret") + msg = "Can not decapsulate secret" + raise RuntimeError(msg) def free(self) -> None: """Releases the native resources.""" @@ -428,16 +434,16 @@ def is_kem_enabled(alg_name: str) -> bool: _KEM_alg_ids = [native().OQS_KEM_alg_identifier(i) for i in range(native().OQS_KEM_alg_count())] -_supported_KEMs: list[str] = [i.decode() for i in _KEM_alg_ids] # noqa: N816 -_enabled_KEMs: list[str] = [i for i in _supported_KEMs if is_kem_enabled(i)] # noqa: N816 +_supported_KEMs: tuple[str, ...] = tuple([i.decode() for i in _KEM_alg_ids]) # noqa: N816 +_enabled_KEMs: tuple[str, ...] = tuple([i for i in _supported_KEMs if is_kem_enabled(i)]) # noqa: N816 -def get_enabled_kem_mechanisms() -> list[str]: +def get_enabled_kem_mechanisms() -> tuple[str, ...]: """Return the list of enabled KEM mechanisms.""" return _enabled_KEMs -def get_supported_kem_mechanisms() -> list[str]: +def get_supported_kem_mechanisms() -> tuple[str, ...]: """Return the list of supported KEM mechanisms.""" return _supported_KEMs @@ -456,7 +462,7 @@ class Signature(ct.Structure): free | OQS_SIG_free """ - _fields_ = [ + _fields_: ClassVar[list[tuple[str, Any]]] = [ ("method_name", ct.c_char_p), ("alg_version", ct.c_char_p), ("claimed_nist_level", ct.c_ubyte), @@ -517,13 +523,13 @@ def __enter__(self: TSignature) -> TSignature: def __exit__( self, - ctx_type: Union[type[BaseException], None], - ctx_value: Union[BaseException, None], - ctx_traceback: Union[TracebackType, None], + exc_type: Union[type[BaseException], None], + exc_value: Union[BaseException, None], + traceback: Union[TracebackType, None], ) -> None: self.free() - def generate_keypair(self) -> Union[bytes, int]: + def generate_keypair(self) -> bytes: """ Generate a new keypair and returns the public key. @@ -540,14 +546,14 @@ def generate_keypair(self) -> Union[bytes, int]: ) if rv == OQS_SUCCESS: return bytes(public_key) - else: - raise RuntimeError("Can not generate keypair") + msg = "Can not generate keypair" + raise RuntimeError(msg) def export_secret_key(self) -> bytes: """Export the secret key.""" return bytes(self.secret_key) - def sign(self, message: bytes) -> Union[bytes, int]: + def sign(self, message: bytes) -> bytes: """ Signs the provided message and returns the signature. @@ -570,9 +576,9 @@ def sign(self, message: bytes) -> Union[bytes, int]: self.secret_key, ) if rv == OQS_SUCCESS: - return bytes(c_signature[: c_signature_len.value]) - else: - raise RuntimeError("Can not sign message") + return bytes(cast(bytes, c_signature[: c_signature_len.value])) + msg = "Can not sign message" + raise RuntimeError(msg) def verify(self, message: bytes, signature: bytes, public_key: bytes) -> bool: """ @@ -588,7 +594,8 @@ def verify(self, message: bytes, signature: bytes, public_key: bytes) -> bool: c_signature = ct.create_string_buffer(signature, len(signature)) c_signature_len = ct.c_size_t(len(c_signature)) c_public_key = ct.create_string_buffer( - public_key, self._sig.contents.length_public_key + public_key, + self._sig.contents.length_public_key, ) rv = native().OQS_SIG_verify( @@ -599,24 +606,25 @@ def verify(self, message: bytes, signature: bytes, public_key: bytes) -> bool: c_signature_len, c_public_key, ) - return True if rv == OQS_SUCCESS else False + return rv == OQS_SUCCESS - def sign_with_ctx_str(self, message, context): + def sign_with_ctx_str(self, message: bytes, context: bytes) -> bytes: """ - Signs the provided message with context string and returns the signature. + Sign the provided message with context string and returns the signature. :param context: the context string. :param message: the message to sign. """ if context and not self._sig.contents.sig_with_ctx_support: - raise RuntimeError("Signing with context string not supported") + msg = "Signing with context string not supported" + raise RuntimeError(msg) # Provide length to avoid extra null char c_message = ct.create_string_buffer(message, len(message)) c_message_len = ct.c_size_t(len(c_message)) if len(context) == 0: c_context = None - c_context_len = 0 + c_context_len = ct.c_size_t(0) else: c_context = ct.create_string_buffer(context, len(context)) c_context_len = ct.c_size_t(len(c_context)) @@ -635,13 +643,19 @@ def sign_with_ctx_str(self, message, context): self.secret_key, ) if rv == OQS_SUCCESS: - return bytes(c_signature[: c_signature_len.value]) - else: - raise RuntimeError("Can not sign message with context string") + return bytes(cast(bytes, c_signature[: c_signature_len.value])) + msg = "Can not sign message with context string" + raise RuntimeError(msg) - def verify_with_ctx_str(self, message, signature, context, public_key): + def verify_with_ctx_str( + self, + message: bytes, + signature: bytes, + context: bytes, + public_key: bytes, + ) -> bool: """ - Verifies the provided signature on the message with context string; returns True if valid. + Verify the provided signature on the message with context string; returns True if valid. :param message: the signed message. :param signature: the signature on the message. @@ -649,7 +663,8 @@ def verify_with_ctx_str(self, message, signature, context, public_key): :param public_key: the signer's public key. """ if context and not self._sig.contents.sig_with_ctx_support: - raise RuntimeError("Verifying with context string not supported") + msg = "Verifying with context string not supported" + raise RuntimeError(msg) # Provide length to avoid extra null char c_message = ct.create_string_buffer(message, len(message)) @@ -658,12 +673,13 @@ def verify_with_ctx_str(self, message, signature, context, public_key): c_signature_len = ct.c_size_t(len(c_signature)) if len(context) == 0: c_context = None - c_context_len = 0 + c_context_len = ct.c_size_t(0) else: c_context = ct.create_string_buffer(context, len(context)) c_context_len = ct.c_size_t(len(c_context)) c_public_key = ct.create_string_buffer( - public_key, self._sig.contents.length_public_key + public_key, + self._sig.contents.length_public_key, ) rv = native().OQS_SIG_verify_with_ctx_str( @@ -705,15 +721,15 @@ def is_sig_enabled(alg_name: str) -> bool: _sig_alg_ids = [native().OQS_SIG_alg_identifier(i) for i in range(native().OQS_SIG_alg_count())] -_supported_sigs = [i.decode() for i in _sig_alg_ids] -_enabled_sigs = [i for i in _supported_sigs if is_sig_enabled(i)] +_supported_sigs: tuple[str, ...] = tuple([i.decode() for i in _sig_alg_ids]) +_enabled_sigs: tuple[str, ...] = tuple([i for i in _supported_sigs if is_sig_enabled(i)]) -def get_enabled_sig_mechanisms() -> list[str]: +def get_enabled_sig_mechanisms() -> tuple[str, ...]: """Return the list of enabled signature mechanisms.""" return _enabled_sigs -def get_supported_sig_mechanisms() -> list[str]: +def get_supported_sig_mechanisms() -> tuple[str, ...]: """Return the list of supported signature mechanisms.""" return _supported_sigs diff --git a/pyproject.toml b/pyproject.toml index ab9d89e..9341803 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ authors = [ { name = "Open Quantum Safe project", email = "contact@openquantumsafe.org" }, ] readme = "README.md" -license = { file = "LICENSE" } +license = { file = "LICENSE.txt" } dependencies = [] [tool.uv] @@ -16,14 +16,13 @@ package = true [project.optional-dependencies] dev = [ "isort==5.13.2", - "pre-commit==4.0.1", - "ruff==0.8.2", - "bandit==1.8.0", + "pre-commit==4.1.0", + "ruff==0.9.3", "nose2==0.15.1", ] lint = [ - "mypy==1.13.0", - "types-pytz==2024.2.0.20241003", + "mypy==1.14.1", + "types-pytz==2024.2.0.20241221", ] [build-system] @@ -70,7 +69,7 @@ exclude = [ select = ["ALL"] ignore = [ "A003", - "ANN002", "ANN003", "ANN101", "ANN102", "ANN401", + "ANN002", "ANN003", "ANN401", "C901", "D100", "D101", "D102", "D103", "D104", "D105", "D106", "D107", "D203", "D205", "D212", "ERA001", @@ -78,11 +77,11 @@ ignore = [ "FBT001", "FBT002", "FIX002", "I001", - "PLR0911", "PLR0912", "PLR0913", "PLR0915", "PLR5501", + "PLR0911", "PLR0912", "PLR0913", "PLR0915", "PLR2004", "PLR5501", "PLW0120", "RUF001", "TD002", "TD003", - "U007", + "UP007", ] [tool.ruff.format] @@ -103,6 +102,7 @@ disallow_incomplete_defs = true disallow_untyped_calls = true disallow_untyped_defs = true extra_checks = true +follow_imports = "normal" follow_imports_for_stubs = true ignore_missing_imports = false namespace_packages = true diff --git a/tests/test_kem.py b/tests/test_kem.py index f692fcc..fb0e482 100644 --- a/tests/test_kem.py +++ b/tests/test_kem.py @@ -7,7 +7,7 @@ disabled_KEM_patterns = [] # noqa: N816 if platform.system() == "Windows": - disabled_KEM_patterns = [""] + disabled_KEM_patterns = [""] # noqa: N816 def test_correctness() -> tuple[None, str]: @@ -39,21 +39,26 @@ def check_wrong_ciphertext(alg_name: str) -> None: wrong_ciphertext = bytes(random.getrandbits(8) for _ in range(len(ciphertext))) try: shared_secret_client = kem.decap_secret(wrong_ciphertext) - assert shared_secret_client != shared_secret_server + assert shared_secret_client != shared_secret_server # noqa: S101 except RuntimeError: pass except Exception as ex: - raise AssertionError(f"An unexpected exception was raised: {ex}") + msg = f"An unexpected exception was raised: {ex}" + raise AssertionError(msg) from ex def test_not_supported() -> None: try: with oqs.KeyEncapsulation("unsupported_sig"): - raise AssertionError("oqs.MechanismNotSupportedError was not raised.") + pass except oqs.MechanismNotSupportedError: pass except Exception as ex: - raise AssertionError(f"An unexpected exception was raised {ex}") + msg = f"An unexpected exception was raised {ex}" + raise AssertionError(msg) from ex + else: + msg = "oqs.MechanismNotSupportedError was not raised." + raise AssertionError(msg) def test_not_enabled() -> None: @@ -62,30 +67,41 @@ def test_not_enabled() -> None: # Found a non-enabled but supported alg try: with oqs.KeyEncapsulation(alg_name): - raise AssertionError("oqs.MechanismNotEnabledError was not raised.") + pass except oqs.MechanismNotEnabledError: pass except Exception as ex: - raise AssertionError(f"An unexpected exception was raised: {ex}") + msg = f"An unexpected exception was raised: {ex}" + raise AssertionError(msg) from ex + else: + msg = "oqs.MechanismNotEnabledError was not raised." + raise AssertionError(msg) -def test_python_attributes(): +def test_python_attributes() -> None: for alg_name in oqs.get_enabled_kem_mechanisms(): with oqs.KeyEncapsulation(alg_name) as kem: if kem.method_name.decode() != alg_name: - raise AssertionError("Incorrect oqs.KeyEncapsulation.method_name") + msg = "Incorrect oqs.KeyEncapsulation.method_name" + raise AssertionError(msg) if kem.alg_version is None: - raise AssertionError("Undefined oqs.KeyEncapsulation.alg_version") + msg = "Undefined oqs.KeyEncapsulation.alg_version" + raise AssertionError(msg) if not 1 <= kem.claimed_nist_level <= 5: - raise AssertionError("Invalid oqs.KeyEncapsulation.claimed_nist_level") + msg = "Invalid oqs.KeyEncapsulation.claimed_nist_level" + raise AssertionError(msg) if kem.length_public_key == 0: - raise AssertionError("Incorrect oqs.KeyEncapsulation.length_public_key") + msg = "Incorrect oqs.KeyEncapsulation.length_public_key" + raise AssertionError(msg) if kem.length_secret_key == 0: - raise AssertionError("Incorrect oqs.KeyEncapsulation.length_secret_key") + msg = "Incorrect oqs.KeyEncapsulation.length_secret_key" + raise AssertionError(msg) if kem.length_ciphertext == 0: - raise AssertionError("Incorrect oqs.KeyEncapsulation.length_signature") + msg = "Incorrect oqs.KeyEncapsulation.length_signature" + raise AssertionError(msg) if kem.length_shared_secret == 0: - raise AssertionError("Incorrect oqs.KeyEncapsulation.length_shared_secret") + msg = "Incorrect oqs.KeyEncapsulation.length_shared_secret" + raise AssertionError(msg) if __name__ == "__main__": @@ -94,6 +110,5 @@ def test_python_attributes(): nose2.main() except ImportError: - raise RuntimeError( - "nose2 module not found. Please install it with 'pip install nose2'." - ) + msg_ = "nose2 module not found. Please install it with 'pip install nose2'." + raise RuntimeError(msg_) from None diff --git a/tests/test_sig.py b/tests/test_sig.py index b198086..b579e1a 100644 --- a/tests/test_sig.py +++ b/tests/test_sig.py @@ -2,7 +2,6 @@ import random import oqs - from oqs.oqs import Signature # Sigs for which unit testing is disabled @@ -19,7 +18,7 @@ def test_correctness() -> tuple[None, str]: yield check_correctness, alg_name -def test_correctness_with_ctx_str(): +def test_correctness_with_ctx_str() -> tuple[None, str]: for alg_name in oqs.get_enabled_sig_mechanisms(): if not Signature(alg_name).details["sig_with_ctx_support"]: continue @@ -36,13 +35,13 @@ def check_correctness(alg_name: str) -> None: assert sig.verify(message, signature, public_key) # noqa: S101 -def check_correctness_with_ctx_str(alg_name): +def check_correctness_with_ctx_str(alg_name: str) -> None: with oqs.Signature(alg_name) as sig: message = bytes(random.getrandbits(8) for _ in range(100)) - context = "some context".encode() + context = b"some context" public_key = sig.generate_keypair() signature = sig.sign_with_ctx_str(message, context) - assert sig.verify_with_ctx_str(message, signature, context, public_key) + assert sig.verify_with_ctx_str(message, signature, context, public_key) # noqa: S101 def test_wrong_message() -> tuple[None, str]: @@ -96,11 +95,15 @@ def check_wrong_public_key(alg_name: str) -> None: def test_not_supported() -> None: try: with oqs.Signature("unsupported_sig"): - raise AssertionError("oqs.MechanismNotSupportedError was not raised.") + pass except oqs.MechanismNotSupportedError: pass except Exception as ex: - raise AssertionError(f"An unexpected exception was raised: {ex}") + msg = f"An unexpected exception was raised: {ex}" + raise AssertionError(msg) from ex + else: + msg = "oqs.MechanismNotSupportedError was not raised." + raise AssertionError(msg) def test_not_enabled() -> None: @@ -109,28 +112,38 @@ def test_not_enabled() -> None: # Found a non-enabled but supported alg try: with oqs.Signature(alg_name): - raise AssertionError("oqs.MechanismNotEnabledError was not raised.") + pass except oqs.MechanismNotEnabledError: pass except Exception as ex: - raise AssertionError(f"An unexpected exception was raised: {ex}") + msg = f"An unexpected exception was raised: {ex}" + raise AssertionError(msg) from ex + else: + msg = "oqs.MechanismNotEnabledError was not raised." + raise AssertionError(msg) -def test_python_attributes(): +def test_python_attributes() -> None: for alg_name in oqs.get_enabled_sig_mechanisms(): with oqs.Signature(alg_name) as sig: if sig.method_name.decode() != alg_name: - raise AssertionError("Incorrect oqs.Signature.method_name") + msg = "Incorrect oqs.Signature.method_name" + raise AssertionError(msg) if sig.alg_version is None: - raise AssertionError("Undefined oqs.Signature.alg_version") + msg = "Undefined oqs.Signature.alg_version" + raise AssertionError(msg) if not 1 <= sig.claimed_nist_level <= 5: - raise AssertionError("Invalid oqs.Signature.claimed_nist_level") + msg = "Invalid oqs.Signature.claimed_nist_level" + raise AssertionError(msg) if sig.length_public_key == 0: - raise AssertionError("Incorrect oqs.Signature.length_public_key") + msg = "Incorrect oqs.Signature.length_public_key" + raise AssertionError(msg) if sig.length_secret_key == 0: - raise AssertionError("Incorrect oqs.Signature.length_secret_key") + msg = "Incorrect oqs.Signature.length_secret_key" + raise AssertionError(msg) if sig.length_signature == 0: - raise AssertionError("Incorrect oqs.Signature.length_signature") + msg = "Incorrect oqs.Signature.length_signature" + raise AssertionError(msg) if __name__ == "__main__": @@ -139,6 +152,5 @@ def test_python_attributes(): nose2.main() except ImportError: - raise RuntimeError( - "nose2 module not found. Please install it with 'pip install nose2'." - ) + msg_ = "nose2 module not found. Please install it with 'pip install nose2'." + raise RuntimeError(msg_) from None From 639fce80106450e8dae89c81c44524dbd246d7dd Mon Sep 17 00:00:00 2001 From: andrew000 <11490628+andrew000@users.noreply.github.com> Date: Sun, 26 Jan 2025 20:50:15 +0200 Subject: [PATCH 04/10] Fix `_fields_` in `Signature` Add stream handler to logger Signed-off-by: andrew000 <11490628+andrew000@users.noreply.github.com> --- oqs/oqs.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/oqs/oqs.py b/oqs/oqs.py index 5363374..7409465 100644 --- a/oqs/oqs.py +++ b/oqs/oqs.py @@ -31,6 +31,8 @@ TSignature = TypeVar("TSignature", bound="Signature") logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logger.addHandler(logging.StreamHandler(stdout)) def oqs_python_version() -> Union[str, None]: @@ -467,6 +469,7 @@ class Signature(ct.Structure): ("alg_version", ct.c_char_p), ("claimed_nist_level", ct.c_ubyte), ("euf_cma", ct.c_ubyte), + ("sig_with_ctx_support", ct.c_ubyte), ("length_public_key", ct.c_size_t), ("length_secret_key", ct.c_size_t), ("length_signature", ct.c_size_t), From 44ef24a13e8f33fe7ed366e6a09e6b30e9a5efb4 Mon Sep 17 00:00:00 2001 From: andrew000 <11490628+andrew000@users.noreply.github.com> Date: Sun, 26 Jan 2025 21:00:42 +0200 Subject: [PATCH 05/10] Add stream handler to logger in rand.py and kem.py Signed-off-by: andrew000 <11490628+andrew000@users.noreply.github.com> --- examples/kem.py | 2 +- examples/rand.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/kem.py b/examples/kem.py index f689213..355eef4 100644 --- a/examples/kem.py +++ b/examples/kem.py @@ -5,9 +5,9 @@ import oqs -logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +logger.addHandler(logging.StreamHandler()) logger.info("liboqs version: %s", oqs.oqs_version()) logger.info("liboqs-python version: %s", oqs.oqs_python_version()) diff --git a/examples/rand.py b/examples/rand.py index 118fdfd..1b1a5f8 100644 --- a/examples/rand.py +++ b/examples/rand.py @@ -6,9 +6,9 @@ import oqs.rand as oqsrand # must be explicitly imported from oqs import oqs_python_version, oqs_version -logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +logger.addHandler(logging.StreamHandler()) logger.info("liboqs version: %s", oqs_version()) logger.info("liboqs-python version: %s", oqs_python_version()) From a01fb50a6ba07fe91186bb6f1a262b403aad7d3d Mon Sep 17 00:00:00 2001 From: andrew000 <11490628+andrew000@users.noreply.github.com> Date: Sun, 26 Jan 2025 21:04:08 +0200 Subject: [PATCH 06/10] Add stream handler to logger in sig.py Signed-off-by: andrew000 <11490628+andrew000@users.noreply.github.com> --- examples/sig.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/sig.py b/examples/sig.py index 0e70fbe..ad47ee4 100644 --- a/examples/sig.py +++ b/examples/sig.py @@ -5,9 +5,9 @@ import oqs -logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +logger.addHandler(logging.StreamHandler()) logger.info("liboqs version: %s", oqs.oqs_version()) logger.info("liboqs-python version: %s", oqs.oqs_python_version()) From 59c571b29e8c23f83462b1d3288076c0e61a3ce5 Mon Sep 17 00:00:00 2001 From: Vlad Gheorghiu Date: Wed, 29 Jan 2025 15:59:55 -0500 Subject: [PATCH 07/10] Updated examples --- examples/kem.py | 10 +++++++--- examples/rand.py | 3 ++- examples/sig.py | 7 ++++--- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/examples/kem.py b/examples/kem.py index 355eef4..799f68b 100644 --- a/examples/kem.py +++ b/examples/kem.py @@ -2,22 +2,26 @@ import logging from pprint import pformat +from sys import stdout import oqs logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -logger.addHandler(logging.StreamHandler()) +logger.addHandler(logging.StreamHandler(stdout)) logger.info("liboqs version: %s", oqs.oqs_version()) logger.info("liboqs-python version: %s", oqs.oqs_python_version()) -logger.info("Enabled KEM mechanisms: %s", pformat(oqs.get_enabled_kem_mechanisms(), compact=True)) +logger.info( + "Enabled KEM mechanisms:\n%s", + pformat(oqs.get_enabled_kem_mechanisms(), compact=True), +) # Create client and server with sample KEM mechanisms kemalg = "ML-KEM-512" with oqs.KeyEncapsulation(kemalg) as client: with oqs.KeyEncapsulation(kemalg) as server: - logger.info("Key encapsulation details: %s", pformat(client.details)) + logger.info("Key encapsulation details:\n%s", pformat(client.details)) # Client generates its keypair public_key_client = client.generate_keypair() diff --git a/examples/rand.py b/examples/rand.py index 1b1a5f8..6f3d9f5 100644 --- a/examples/rand.py +++ b/examples/rand.py @@ -2,13 +2,14 @@ import logging import platform # to learn the OS we're on +from sys import stdout import oqs.rand as oqsrand # must be explicitly imported from oqs import oqs_python_version, oqs_version logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -logger.addHandler(logging.StreamHandler()) +logger.addHandler(logging.StreamHandler(stdout)) logger.info("liboqs version: %s", oqs_version()) logger.info("liboqs-python version: %s", oqs_python_version()) diff --git a/examples/sig.py b/examples/sig.py index ad47ee4..3a432f4 100644 --- a/examples/sig.py +++ b/examples/sig.py @@ -2,17 +2,18 @@ import logging from pprint import pformat +from sys import stdout import oqs logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -logger.addHandler(logging.StreamHandler()) +logger.addHandler(logging.StreamHandler(stdout)) logger.info("liboqs version: %s", oqs.oqs_version()) logger.info("liboqs-python version: %s", oqs.oqs_python_version()) logger.info( - "Enabled signature mechanisms: %s", + "Enabled signature mechanisms:\n%s", pformat(oqs.get_enabled_sig_mechanisms(), compact=True), ) @@ -21,7 +22,7 @@ # Create signer and verifier with sample signature mechanisms sigalg = "ML-DSA-44" with oqs.Signature(sigalg) as signer, oqs.Signature(sigalg) as verifier: - logger.info("Signature details: %s", pformat(signer.details)) + logger.info("Signature details:\n%s", pformat(signer.details)) # Signer generates its keypair signer_public_key = signer.generate_keypair() From 5dd05577623d95dddecb3755e5b956d3e9731b40 Mon Sep 17 00:00:00 2001 From: andrew000 <11490628+andrew000@users.noreply.github.com> Date: Sat, 1 Feb 2025 13:11:09 +0200 Subject: [PATCH 08/10] Bump `ruff`, `0.9.3` -> `0.9.4` Signed-off-by: andrew000 <11490628+andrew000@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f49bc59..393527b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: - id: "check-json" - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.9.3 + rev: v0.9.4 hooks: - id: ruff args: [ "--fix" ] diff --git a/pyproject.toml b/pyproject.toml index 9341803..6011381 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ package = true dev = [ "isort==5.13.2", "pre-commit==4.1.0", - "ruff==0.9.3", + "ruff==0.9.4", "nose2==0.15.1", ] lint = [ From 8d0465b9c1435abe8628957c2d4dadbc307bf484 Mon Sep 17 00:00:00 2001 From: andrew000 <11490628+andrew000@users.noreply.github.com> Date: Sat, 1 Feb 2025 13:32:16 +0200 Subject: [PATCH 09/10] Change type hint of `_fields_` from `list` to `Sequence` Signed-off-by: andrew000 <11490628+andrew000@users.noreply.github.com> --- oqs/oqs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/oqs/oqs.py b/oqs/oqs.py index 7409465..1dbcf29 100644 --- a/oqs/oqs.py +++ b/oqs/oqs.py @@ -267,7 +267,7 @@ class KeyEncapsulation(ct.Structure): free | OQS_KEM_free """ - _fields_: ClassVar[list[tuple[str, Any]]] = [ + _fields_: ClassVar[Sequence[tuple[str, Any]]] = [ ("method_name", ct.c_char_p), ("alg_version", ct.c_char_p), ("claimed_nist_level", ct.c_ubyte), @@ -464,7 +464,7 @@ class Signature(ct.Structure): free | OQS_SIG_free """ - _fields_: ClassVar[list[tuple[str, Any]]] = [ + _fields_: ClassVar[Sequence[tuple[str, Any]]] = [ ("method_name", ct.c_char_p), ("alg_version", ct.c_char_p), ("claimed_nist_level", ct.c_ubyte), From 5a872e59671095bf1d343fd9c2edefe2015b5978 Mon Sep 17 00:00:00 2001 From: Vlad Gheorghiu Date: Sun, 2 Feb 2025 15:46:52 -0500 Subject: [PATCH 10/10] Added COM812 to tools.ruff.lint.ignore Signed-off-by: Vlad Gheorghiu --- Makefile | 9 +++++++++ pyproject.toml | 1 + 2 files changed, 10 insertions(+) diff --git a/Makefile b/Makefile index 02b6f6a..9145e76 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,12 @@ +# Code checker/formatter +# +# Pre-requisites +# +# isort +# mypy +# ruff +# uv + src-dir = oqs tests-dir = tests examples-dir = examples diff --git a/pyproject.toml b/pyproject.toml index 6011381..cdeda47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ ignore = [ "A003", "ANN002", "ANN003", "ANN401", "C901", + "COM812", "D100", "D101", "D102", "D103", "D104", "D105", "D106", "D107", "D203", "D205", "D212", "ERA001", "FA100", "FA102",