From 75ed8334e6de3f0303ac597fbf4cfee6c2fa00e9 Mon Sep 17 00:00:00 2001 From: Demi Marie Obenour Date: Sun, 14 Jan 2024 16:34:49 -0500 Subject: [PATCH] Implement UUID support in qrexec --- qrexec/policy/parser.py | 138 ++++++++++++++++++++++++---------- qrexec/tests/policy_parser.py | 52 +++++++++++-- qrexec/utils.py | 3 +- 3 files changed, 145 insertions(+), 48 deletions(-) diff --git a/qrexec/policy/parser.py b/qrexec/policy/parser.py index 8013edb2..9576b2a4 100644 --- a/qrexec/policy/parser.py +++ b/qrexec/policy/parser.py @@ -30,6 +30,7 @@ import itertools import logging import pathlib +import re import string from typing import ( @@ -48,7 +49,7 @@ ) from .. import POLICYPATH, RPCNAME_ALLOWED_CHARSET, POLICYSUFFIX -from ..utils import FullSystemInfo +from ..utils import FullSystemInfo, SystemInfo, SystemInfoEntry from .. import exc from ..exc import ( AccessDenied, @@ -235,11 +236,19 @@ def __new__(cls, token: str, *, filepath: Optional[pathlib.Path]=None, orig_token = token # first, adjust some aliases - if token == "dom0": + if token in ("dom0", "@uuid:00000000-0000-0000-0000-000000000000"): # TODO: log a warning in Qubes 4.1 token = "@adminvm" - # if user specified just qube name, use it directly + # if user specified just qube name or UUID, use it directly + if token.startswith("@uuid:"): + if not _uuid_regex.match(token[6:]): + raise PolicySyntaxError( + filepath, + lineno or 0, + f"invalid UUID: {token[6:]!r}", + ) + return super().__new__(cls, token) if not (token.startswith("@") or token == "*"): return super().__new__(cls, token) @@ -264,7 +273,7 @@ def __new__(cls, token: str, *, filepath: Optional[pathlib.Path]=None, lineno or 0, "invalid empty {} token: {!r}".format(prefix, token), ) - if value.startswith("@"): + if value.startswith("@") and not value.startswith("@uuid:"): # we are either part of a longer prefix (@dispvm:@tag: etc), # or the token is invalid, in which case this will fallthru continue @@ -288,7 +297,7 @@ def __init__(self, token: str, *, filepath: Optional[pathlib.Path]=None, self.lineno = lineno try: self.value = self[len(self.PREFIX) :] # type: ignore - assert self.value[0] != "@" + assert self.value[0] != "@" or self.value.startswith("@uuid:") except AttributeError: # self.value = self pass @@ -300,17 +309,34 @@ def __init__(self, token: str, *, filepath: Optional[pathlib.Path]=None, # This replaces is_match() and is_match_single(). def match( self, - other: Optional[str], + other: str, *, system_info: FullSystemInfo, source: Optional["VMToken"]=None ) -> bool: """Check if this token matches opposite token""" - # pylint: disable=unused-argument + # pylint: disable=unused-argument,too-many-return-statements + if self == "@adminvm": + return other == "@adminvm" + info = system_info["domains"] + if self.startswith("@uuid:"): + if other.startswith("@uuid:"): + return self == other + try: + return self[6:] == info[str(other)]["uuid"] + except KeyError: + return False + if other.startswith("@uuid:"): + try: + return other[6:] == info[str(self)]["uuid"] + except KeyError: + return False return self == other def is_special_value(self) -> bool: """Check if the token specification is special (keyword) value""" + if self.startswith("@uuid:"): + return False return self.startswith("@") or self == "*" @property @@ -339,8 +365,9 @@ def expand(self, *, system_info: FullSystemInfo) -> Iterable[VMToken]: This is used as part of :py:meth:`Policy.collect_targets_for_ask()`. """ - if self in system_info["domains"]: - yield IntendedTarget(self) + domain = get_domain_name(system_info["domains"], self) + if domain is not None: + yield IntendedTarget(type(self)(domain)) class Target(_BaseTarget): @@ -362,10 +389,43 @@ def __new__( return super().__new__(cls, value, filepath=filepath, lineno=lineno) # type: ignore +_uuid_regex = re.compile(r"\A[0-9a-f]{8}(?:-[0-9a-f]{4}){3}-[0-9a-f]{12}\Z") + +# FIXME: this is O(n) for UUIDs +def get_domain(info: SystemInfo, domain: str) -> Optional[SystemInfoEntry]: + assert isinstance(domain, str) + if domain.startswith("@uuid:"): + uuid = domain[6:] + if not _uuid_regex.match(uuid): + return None + for domain_info in info.values(): + if domain_info["uuid"] == uuid: + return domain_info + return None + try: + return info[domain] + except KeyError: + return None + +# FIXME: this is O(n) for UUIDs +def get_domain_name(info: SystemInfo, domain: str) -> Optional[str]: + assert isinstance(domain, str) + if domain.startswith("@uuid:"): + uuid = domain[6:] + if not _uuid_regex.match(uuid): + return None + for name, domain_info in info.items(): + if domain_info["uuid"] == uuid: + return name + return None + if domain in info: + return domain + return None + # this method (with overloads in subclasses) was verify_target_value class IntendedTarget(VMToken): # pylint: disable=missing-docstring - def verify(self, *, system_info: FullSystemInfo) -> VMToken: + def verify(self, *, system_info: FullSystemInfo) -> Optional[VMToken]: """Check if given value names valid target This function check if given value is not only syntactically correct, @@ -387,7 +447,7 @@ def verify(self, *, system_info: FullSystemInfo) -> VMToken: if type(self) != IntendedTarget: raise NotImplementedError() - if self not in system_info["domains"]: + if get_domain(system_info["domains"], self) is None: logging.warning( "qrexec: target %r does not exist, using @default instead", str(self), @@ -410,7 +470,7 @@ class WildcardVM(Source, Target): def match( self, - other: Optional[str], + other: str, *, system_info: FullSystemInfo, source: Optional[VMToken]=None @@ -443,7 +503,7 @@ class AnyVM(Source, Target): def match( self, - other: Optional[str], + other: str, *, system_info: FullSystemInfo, source: Optional[VMToken]=None @@ -476,15 +536,15 @@ class TypeVM(Source, Target): def match( self, - other: Optional[str], + other: str, *, system_info: FullSystemInfo, source: Optional[VMToken]=None ) -> bool: - _system_info = system_info["domains"] + other_vm = get_domain(system_info["domains"], other) return ( - other in _system_info - and self.value == _system_info[other]["type"] + other_vm is not None + and self.value == other_vm["type"] ) def expand(self, *, system_info: FullSystemInfo) -> Iterable[IntendedTarget]: @@ -499,15 +559,15 @@ class TagVM(Source, Target): def match( self, - other: Optional[str], + other: str, *, system_info: FullSystemInfo, source: Optional[VMToken]=None ) -> bool: - _system_info = system_info["domains"] + other_vm = get_domain(system_info["domains"], other) return ( - other in _system_info - and self.value in _system_info[other]["tags"] + other_vm is not None + and self.value in other_vm["tags"] ) def expand(self, *, system_info: FullSystemInfo) -> Iterable[IntendedTarget]: @@ -522,7 +582,7 @@ class DispVM(Target, Redirect, IntendedTarget): def match( self, - other: Optional[str], + other: str, *, system_info: FullSystemInfo, source: Optional[VMToken]=None @@ -542,10 +602,10 @@ def get_dispvm_template( system_info: FullSystemInfo, ) -> Optional["DispVMTemplate"]: """Given source, get appropriate template for DispVM. Maybe None.""" - _system_info = system_info["domains"] - if source not in _system_info: + source_info = get_domain(system_info["domains"], source) + if source_info is None: return None - template = _system_info[source].get("default_dispvm", None) + template = source_info.get("default_dispvm", None) if template is None: return None return DispVMTemplate("@dispvm:" + template) @@ -556,7 +616,7 @@ class DispVMTemplate(Source, Target, Redirect, IntendedTarget): def match( self, - other: Optional[str], + other: str, *, system_info: FullSystemInfo, source: Optional[VMToken]=None @@ -568,16 +628,17 @@ def match( return self == other def expand(self, *, system_info: FullSystemInfo) -> Iterable["DispVMTemplate"]: - if system_info["domains"][self.value]["template_for_dispvms"]: + domain = get_domain(system_info["domains"], self.value) + assert domain is not None + if domain["template_for_dispvms"]: yield self # else: log a warning? - def verify(self, *, system_info: FullSystemInfo) -> "DispVMTemplate": - _system_info = system_info["domains"] - if ( - self.value not in _system_info - or not _system_info[self.value]["template_for_dispvms"] - ): + def verify(self, *, system_info: FullSystemInfo) -> Optional["DispVMTemplate"]: + self_info = get_domain(system_info["domains"], self.value) + if self_info is None: + return None + if not self_info["template_for_dispvms"]: raise AccessDenied( "not a template for dispvm: {}".format(self.value) ) @@ -590,22 +651,21 @@ class DispVMTag(Source, Target): def match( self, - other: Optional[str], + other: str, *, system_info: FullSystemInfo, source: Optional[VMToken]=None ) -> bool: - if isinstance(other, DispVM): - assert source is not None - other = other.get_dispvm_template(source, system_info=system_info) + if isinstance(other, DispVM) and source is not None: + return self == other.get_dispvm_template(source, system_info=system_info) if not isinstance(other, DispVMTemplate): # 1) original other may have been neither @dispvm: nor @dispvm # 2) other.get_dispvm_template() may have been None return False - domain = system_info["domains"][other.value] - if not domain["template_for_dispvms"]: + domain = get_domain(system_info["domains"], other.value) + if domain is None or not domain["template_for_dispvms"]: return False if not self.value in domain["tags"]: return False diff --git a/qrexec/tests/policy_parser.py b/qrexec/tests/policy_parser.py index e0e9a112..83501113 100644 --- a/qrexec/tests/policy_parser.py +++ b/qrexec/tests/policy_parser.py @@ -38,6 +38,7 @@ "default_dispvm": "default-dvm", "template_for_dispvms": False, "power_state": "Running", + "uuid": "00000000-0000-0000-0000-000000000000", }, "test-vm1": { "tags": ["tag1", "tag2"], @@ -45,6 +46,7 @@ "default_dispvm": "default-dvm", "template_for_dispvms": False, "power_state": "Running", + "uuid": "c9024a97-9b15-46cc-8341-38d75d5d421b", }, "test-vm2": { "tags": ["tag2"], @@ -52,6 +54,7 @@ "default_dispvm": "default-dvm", "template_for_dispvms": False, "power_state": "Running", + "uuid": "b3eb69d0-f9d9-4c3c-ad5c-454500303ea4", }, "test-vm3": { "tags": ["tag3"], @@ -59,6 +62,7 @@ "default_dispvm": "default-dvm", "template_for_dispvms": True, "power_state": "Halted", + "uuid": "fa6d56e8-a89d-4106-aa62-22e172a43c8b", }, "default-dvm": { "tags": [], @@ -66,6 +70,7 @@ "default_dispvm": "default-dvm", "template_for_dispvms": True, "power_state": "Halted", + "uuid": "f3e538bd-4427-4697-bed7-45ef3270df21", }, "test-invalid-dvm": { "tags": ["tag1", "tag2"], @@ -73,6 +78,7 @@ "default_dispvm": "test-vm1", "template_for_dispvms": False, "power_state": "Halted", + "uuid": "c4fa3586-a6b6-4dc4-bdda-c9e7375a12b5", }, "test-no-dvm": { "tags": ["tag1", "tag2"], @@ -80,6 +86,7 @@ "default_dispvm": None, "template_for_dispvms": False, "power_state": "Halted", + "uuid": "53a450b9-a454-4416-8adb-46812257ad29", }, "test-template": { "tags": ["tag1", "tag2"], @@ -87,6 +94,7 @@ "default_dispvm": "default-dvm", "template_for_dispvms": False, "power_state": "Halted", + "uuid": "a9fe2b04-9fd5-4e95-be20-162433d64de0", }, "test-standalone": { "tags": ["tag1", "tag2"], @@ -94,6 +102,7 @@ "default_dispvm": "default-dvm", "template_for_dispvms": False, "power_state": "Halted", + "uuid": "6d7a02b5-532b-467f-b9fb-6596bae03c33", }, }, } @@ -120,6 +129,11 @@ def test_010_Source(self): parser.Source("*") with self.assertRaises(exc.PolicySyntaxError): parser.Source("@default") + parser.Source("@uuid:d8a249f1-b02b-4944-a9e5-437def2fbe2c") + with self.assertRaises(exc.PolicySyntaxError): + parser.Source("@uuid") + with self.assertRaises(exc.PolicySyntaxError): + parser.Source("@uuid:invaliduuid") parser.Source("@type:AppVM") parser.Source("@tag:tag1") with self.assertRaises(exc.PolicySyntaxError): @@ -150,7 +164,12 @@ def test_020_Target(self): parser.Target("@dispvm") parser.Target("@dispvm:default-dvm") parser.Target("@dispvm:@tag:tag3") + parser.Target("@uuid:d8a249f1-b02b-4944-a9e5-437def2fbe2c") + with self.assertRaises(exc.PolicySyntaxError): + parser.Target("@uuid") + with self.assertRaises(exc.PolicySyntaxError): + parser.Target("@uuid:invaliduuid") with self.assertRaises(exc.PolicySyntaxError): parser.Target("@invalid") with self.assertRaises(exc.PolicySyntaxError): @@ -163,19 +182,22 @@ def test_020_Target(self): parser.Target("@type:") def test_021_Target_expand(self): - self.assertCountEqual( - parser.Target("test-vm1").expand(system_info=SYSTEM_INFO), + self.assertEqual( + list(parser.Target("test-vm1").expand(system_info=SYSTEM_INFO)), ["test-vm1"], ) - self.assertCountEqual( - parser.Target("@adminvm").expand(system_info=SYSTEM_INFO), + self.assertEqual( + list(parser.Target("@adminvm").expand(system_info=SYSTEM_INFO)), ["@adminvm"], ) - self.assertCountEqual( - parser.Target("dom0").expand(system_info=SYSTEM_INFO), ["@adminvm"] + self.assertEqual( + list(parser.Target("dom0").expand(system_info=SYSTEM_INFO)), ["@adminvm"] ) - self.assertCountEqual( - parser.Target("@anyvm").expand(system_info=SYSTEM_INFO), + self.assertEqual( + list(parser.Target("@uuid:00000000-0000-0000-0000-000000000000").expand(system_info=SYSTEM_INFO)), ["@adminvm"] + ) + self.assertEqual( + list(parser.Target("@anyvm").expand(system_info=SYSTEM_INFO)), [ "test-vm1", "test-vm2", @@ -286,6 +308,11 @@ def test_030_Redirect(self): parser.Redirect("test-vm1") parser.Redirect("@adminvm") parser.Redirect("dom0") + parser.Redirect("@uuid:00000000-0000-0000-0000-000000000000") + with self.assertRaises(exc.PolicySyntaxError): + parser.Redirect("@uuid") + with self.assertRaises(exc.PolicySyntaxError): + parser.Redirect("@uuid:invaliduuid") with self.assertRaises(exc.PolicySyntaxError): parser.Redirect("@anyvm") with self.assertRaises(exc.PolicySyntaxError): @@ -313,9 +340,14 @@ def test_030_Redirect(self): parser.Redirect("@type:") def test_040_IntendedTarget(self): + parser.IntendedTarget("@uuid:00000000-0000-0000-0000-000000000000") parser.IntendedTarget("test-vm1") parser.IntendedTarget("@adminvm") parser.IntendedTarget("dom0") + with self.assertRaises(exc.PolicySyntaxError): + parser.IntendedTarget("@uuid") + with self.assertRaises(exc.PolicySyntaxError): + parser.IntendedTarget("@uuid:invaliduuid") with self.assertRaises(exc.PolicySyntaxError): parser.IntendedTarget("@anyvm") with self.assertRaises(exc.PolicySyntaxError): @@ -344,6 +376,10 @@ def test_040_IntendedTarget(self): def test_100_match_single(self): # pytest: disable=no-self-use cases = [ + ("@uuid:00000000-0000-0000-0000-000000000000", "@adminvm", True), + ("@uuid:00000000-0000-0000-0000-000000000000", "dom0", True), + ("@uuid:00000000-0000-0000-0000-000000000000", "@dispvm:default-dvm", False), + ("@uuid:00000000-0000-0000-0000-000000000000", "test-vm1", False), ("@anyvm", "test-vm1", True), ("@anyvm", "@default", True), ("@default", "@default", True), diff --git a/qrexec/utils.py b/qrexec/utils.py index 348b36d0..b4dc85ba 100644 --- a/qrexec/utils.py +++ b/qrexec/utils.py @@ -115,6 +115,7 @@ class SystemInfoEntry(TypedDict): power_state: str icon: str guivm: Optional[str] + uuid: Optional[str] SystemInfo: 'TypeAlias' = Dict[str, SystemInfoEntry] @@ -136,7 +137,7 @@ def get_system_info() -> FullSystemInfo: """ system_info = qubesd_call("dom0", "internal.GetSystemInfo") - return cast(SystemInfo, json.loads(system_info.decode("utf-8"))) + return cast(FullSystemInfo, json.loads(system_info.decode("utf-8"))) def prepare_subprocess_kwds(input: object) -> Dict[str, object]: