diff --git a/keyutils/__init__.py b/keyutils/__init__.py index f9be844..f74b279 100644 --- a/keyutils/__init__.py +++ b/keyutils/__init__.py @@ -17,7 +17,7 @@ from __future__ import absolute_import -from typing import Union +from typing import Union, Optional from . import _keyutils @@ -46,7 +46,7 @@ def _handle_keyerror(err: Exception): def add_key(desc, value, keyring, keyType=b"user"): return _keyutils.add_key(keyType, desc, value, keyring) -def add_ring(desc, keyring): +def add_ring(desc, keyring) -> Optional[int]: return _keyutils.add_key(b"keyring", desc, None, keyring) def request_key(keyDesc, keyring, keyType=b"user", callout_info=None): @@ -142,6 +142,10 @@ def dh_compute_kdf(key_priv, key_prime, key_base, hashname, buflen, otherinfo=No return _keyutils.dh_compute_kdf(key_priv, key_prime, key_base, hashname, buflen, otherinfo) +def restrict_keyring(keyring, key_type, restriction): + return _keyutils.restrict_keyring(keyring, key_type, restriction) + + def describe_key(keyId): return _keyutils.describe_key(keyId) diff --git a/keyutils/_keyutils.pyx b/keyutils/_keyutils.pyx index 4d12b78..653a2c0 100644 --- a/keyutils/_keyutils.pyx +++ b/keyutils/_keyutils.pyx @@ -292,6 +292,27 @@ def dh_compute_kdf(int key_priv, int key_prime, int key_base, bytes hashname, in return obj +def restrict_keyring(int keyring, bytes key_type, bytes restriction): + cdef int rc + cdef char *type_p + cdef char *restriction_p + + if key_type is None: + type_p = NULL + else: + type_p = key_type + + if restriction is None: + restriction_p = NULL + else: + restriction_p = restriction + + with nogil: + rc = ckeyutils.restrict_keyring(keyring, type_p, restriction_p) + _throw_err(rc) + return None + + def describe_key(int key): cdef int size cdef char *ptr diff --git a/keyutils/ckeyutils.pxd b/keyutils/ckeyutils.pxd index fb5181f..441ee28 100644 --- a/keyutils/ckeyutils.pxd +++ b/keyutils/ckeyutils.pxd @@ -76,6 +76,7 @@ cdef extern from "keyutils.h" nogil: int get_persistent "keyctl_get_persistent"(uid_t uid, key_serial_t key) int dh_compute_kdf "keyctl_dh_compute_kdf"(key_serial_t priv, key_serial_t prime, key_serial_t base, char *hashname, char *otherinfo, int otherinfolen, char *buffer, size_t buflen) int dh_compute_alloc "keyctl_dh_compute_alloc"(key_serial_t priv, key_serial_t prime, key_serial_t base, void **bufptr) + int restrict_keyring "keyctl_restrict_keyring"(key_serial_t keyring, const char *key_type, const char *restriction) int describe_alloc "keyctl_describe_alloc"(int key, char **bufptr) int read_alloc "keyctl_read_alloc"(int key, void ** bufptr) int get_security_alloc "keyctl_get_security_alloc"(key_serial_t key, char **bufptr) diff --git a/test/crypt_utils.py b/test/crypt_utils.py index 27b94ee..94709b9 100644 --- a/test/crypt_utils.py +++ b/test/crypt_utils.py @@ -1,6 +1,5 @@ """Utilities to generate stub cryptographic elements""" import base64 -import os import re import subprocess from pathlib import Path @@ -19,6 +18,7 @@ def read_pem_object(head: str, it): key_data += line.strip() raise StopIteration(f"Reached end of key block but did not find trailer. head is {key_name}") + def parse_openssl_text(ls): keys = [] result = {} @@ -40,7 +40,7 @@ def parse_openssl_text(ls): current_field = line[:field_name_end] if not line.endswith(":"): # data is inline - current_data = line[field_name_end+1:].strip() + current_data = line[field_name_end + 1:].strip() else: current_data = "" # Reset the data list for the new field @@ -57,14 +57,14 @@ def parse_openssl_text(ls): def process_openssl_objects(objs: dict): out = {} - for k,v in objs.items(): + for k, v in objs.items(): if k in {"private-key", "public-key", "P"}: # out[k] = bytes.fromhex(v.replace(":", "")) out[k] = v.replace(":", "").encode("ascii") elif re.match("\d+ \(0x\d+\)", v): # matches '2 (0x2)' match = re.match("\d+ \(0x(\d+)\)", v) out[k] = match.group(0).encode("utf-8") - out[k] = b"02" # TODO: this is a shim + out[k] = b"02" # TODO: this is a shim else: out[k] = v return out @@ -82,14 +82,15 @@ def gen_dh(workdir: Path, gen_dhparam: bool): workdir.mkdir(exist_ok=True, parents=True) subprocess.run(["openssl", "dhparam", "-check", "-out", str(workdir / "dh.pem"), "2048"]) - keyinfo = subprocess.run(["openssl", "genpkey", "-paramfile", str(workdir / "dh.pem"), "-text"], capture_output=True, text=True) + keyinfo = subprocess.run(["openssl", "genpkey", "-paramfile", str(workdir / "dh.pem"), "-text"], + capture_output=True, text=True) return keyinfo.stdout def dh_keys(workdir: Path, regen=True): openssl_output = gen_dh(workdir, regen).splitlines() - keys, objects= parse_openssl_text(openssl_output) + keys, objects = parse_openssl_text(openssl_output) objects = process_openssl_objects(objects) keys = process_key_bodies(keys) return keys, objects @@ -103,6 +104,51 @@ def extract_dh_keyring_items(keys, objects): } +def gen_rsa(workdir, regen: bool): + privkey_path = str(workdir / "rsa.pem") + if regen: + workdir.mkdir(exist_ok=True, parents=True) + subprocess.run(["openssl", "genpkey", "-algorithm", "RSA", "-out", privkey_path]) + + der_path = str(workdir / "rsa.x509.der") + subprocess.run(["openssl", "req", "-new", "-x509", "-key", privkey_path, "-outform", "DER", "-days", "365", "-out", der_path, "-subj", "/C=CA/O=example/CN=turkeyutils-ca"]) + + with open(der_path, mode="rb") as der_file: + der = der_file.read() + + return der + + +def gen_child_cert(workdir, cadir, regen: bool): + privkey_path = str(workdir / "rsa.pem") + if regen: + workdir.mkdir(exist_ok=True, parents=True) + subprocess.run(["openssl", "genpkey", "-algorithm", "RSA", "-out", privkey_path]) + + csr_path = str(workdir / "rsa.crt") + der_path = str(workdir / "rsa.x509.der") + + subprocess.run(["openssl", "req", "-new", "-key", privkey_path, "-out", csr_path, "-subj", "/C=CA/O=example/CN=turkeyutils-leaf"]) + subprocess.run(["openssl", "x509", "-req", "-in", csr_path, "-CA", str(cadir/"rsa.x509.der"), "-CAkey", str(cadir/"rsa.pem"), "-CAcreateserial", "-days", "365", "-outform", "DER", "-out", der_path]) + + with open(der_path, mode="rb") as der_file: + der = der_file.read() + + return der + + +def rsa_keys(workdir: Path, regen=True): + regen = True + ca = gen_rsa(workdir / "ca", regen) + unsigned = gen_rsa(workdir / "unsigned", regen) + leaf = gen_child_cert(workdir / "leaf", workdir / "ca", regen) + return { + "ca": ca, + "unsigned": unsigned, + "leaf": leaf, + } + + if __name__ == "__main__": keys, objects = dh_keys(Path.cwd(), regen=False) print(keys, objects) diff --git a/test/keyutils_test.py b/test/keyutils_test.py index f000862..91382c8 100644 --- a/test/keyutils_test.py +++ b/test/keyutils_test.py @@ -32,6 +32,7 @@ def ring(request): return keyutils.add_ring(request.function.__name__.encode("utf-8"), keyutils.KEY_SPEC_THREAD_KEYRING) + class BasicTest(unittest.TestCase): def testSet(self): keyDesc = b"test:key:01" @@ -194,6 +195,7 @@ def testInvalidate(self): with pytest.raises(keyutils.KeyutilsError): # TODO: more specific error check keyutils.read_key(key_id) + class TestBasic: def testGetPersistent(self, ring): @@ -209,6 +211,7 @@ def testGetSecurity(self, ring): security = keyutils.get_security(ring) assert security == b'' # TODO: find out how to apply security labels + def test_get_keyring_id(): keyring = keyutils.get_keyring_id(keyutils.KEY_SPEC_THREAD_KEYRING, False) assert keyring is not None and keyring != 0 @@ -281,7 +284,6 @@ def test_compute(self, dh_keys): assert v assert len(v) == 520 - def test_kdf(self, dh_keys): keys = {k: keyutils.add_key(k.encode("utf-8"), v, keyutils.KEY_SPEC_THREAD_KEYRING) for k, v in dh_keys.items()} @@ -290,5 +292,44 @@ def test_kdf(self, dh_keys): assert len(v) == 1024 +@pytest.fixture +def rsa_keys(tmpdir): + regen = not Path("/tmp/rsa/ca/rsa.pem").exists() + out = Path("/tmp/rsa") + + return crypt_utils.rsa_keys(out, regen=regen) + + +class TestRestrict: + def test_block_all(self, ring): + keyutils.restrict_keyring(ring, None, None) + + with pytest.raises(keyutils.KeyutilsError) as e: + keyutils.add_key(b"test_restrict_n", b"test_restrict_v", ring) + assert e.value.args[1] == 'Operation not permitted' + + def test_restrict_keyring(self, rsa_keys): + allowed_ring = keyutils.add_ring(b"test_restrict_keyring_allowed", keyutils.KEY_SPEC_THREAD_KEYRING) + allowed_key = keyutils.add_key(b"restrict_allowed", rsa_keys["ca"], allowed_ring, b"asymmetric") + target_ring = keyutils.add_ring(b"test_restrict_keyring_target", keyutils.KEY_SPEC_THREAD_KEYRING) + + print(f"key_or_keyring:{allowed_ring}") + keyutils.restrict_keyring(target_ring, b"asymmetric", f"key_or_keyring:{allowed_ring}".encode("ascii")) + + # check we can add a permitted key + keyutils.link(allowed_key, target_ring) + # check we can't add user keys + with pytest.raises(keyutils.KeyutilsError) as e: + keyutils.add_key(b"test_restrict_user", b"test_restrict_v", target_ring) + assert e.value.args[1] == 'Operation not supported' + # check we can't add a random x509 + with pytest.raises(keyutils.KeyutilsError) as e: + keyutils.add_key(b"test_restrict_unsigned", rsa_keys["unsigned"], target_ring, b"asymmetric") + assert e.value.args[1] == 'Required key not available' + # check we can add a signed key + leaf_key = keyutils.add_key(b"test_restrict_leaf", rsa_keys["leaf"], target_ring, b"asymmetric") + assert leaf_key + + if __name__ == "__main__": sys.exit(unittest.main())