Skip to content

Commit

Permalink
Implement UUID support in qrexec
Browse files Browse the repository at this point in the history
  • Loading branch information
DemiMarie committed Jan 23, 2024
1 parent 171a681 commit 75ed833
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 48 deletions.
138 changes: 99 additions & 39 deletions qrexec/policy/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import itertools
import logging
import pathlib
import re
import string

from typing import (
Expand All @@ -48,7 +49,7 @@
)

from .. import POLICYPATH, RPCNAME_ALLOWED_CHARSET, POLICYSUFFIX
from ..utils import FullSystemInfo
from ..utils import FullSystemInfo, SystemInfo, SystemInfoEntry
from .. import exc
from ..exc import (
AccessDenied,
Expand Down Expand Up @@ -235,11 +236,19 @@ def __new__(cls, token: str, *, filepath: Optional[pathlib.Path]=None,
orig_token = token

# first, adjust some aliases
if token == "dom0":
if token in ("dom0", "@uuid:00000000-0000-0000-0000-000000000000"):
# TODO: log a warning in Qubes 4.1
token = "@adminvm"

# if user specified just qube name, use it directly
# if user specified just qube name or UUID, use it directly
if token.startswith("@uuid:"):
if not _uuid_regex.match(token[6:]):
raise PolicySyntaxError(
filepath,
lineno or 0,
f"invalid UUID: {token[6:]!r}",
)
return super().__new__(cls, token)
if not (token.startswith("@") or token == "*"):
return super().__new__(cls, token)

Expand All @@ -264,7 +273,7 @@ def __new__(cls, token: str, *, filepath: Optional[pathlib.Path]=None,
lineno or 0,
"invalid empty {} token: {!r}".format(prefix, token),
)
if value.startswith("@"):
if value.startswith("@") and not value.startswith("@uuid:"):
# we are either part of a longer prefix (@dispvm:@tag: etc),
# or the token is invalid, in which case this will fallthru
continue
Expand All @@ -288,7 +297,7 @@ def __init__(self, token: str, *, filepath: Optional[pathlib.Path]=None,
self.lineno = lineno
try:
self.value = self[len(self.PREFIX) :] # type: ignore
assert self.value[0] != "@"
assert self.value[0] != "@" or self.value.startswith("@uuid:")
except AttributeError:
# self.value = self
pass
Expand All @@ -300,17 +309,34 @@ def __init__(self, token: str, *, filepath: Optional[pathlib.Path]=None,
# This replaces is_match() and is_match_single().
def match(
self,
other: Optional[str],
other: str,
*,
system_info: FullSystemInfo,
source: Optional["VMToken"]=None
) -> bool:
"""Check if this token matches opposite token"""
# pylint: disable=unused-argument
# pylint: disable=unused-argument,too-many-return-statements
if self == "@adminvm":
return other == "@adminvm"
info = system_info["domains"]
if self.startswith("@uuid:"):
if other.startswith("@uuid:"):
return self == other
try:
return self[6:] == info[str(other)]["uuid"]
except KeyError:
return False
if other.startswith("@uuid:"):
try:
return other[6:] == info[str(self)]["uuid"]
except KeyError:
return False
return self == other

def is_special_value(self) -> bool:
"""Check if the token specification is special (keyword) value"""
if self.startswith("@uuid:"):
return False
return self.startswith("@") or self == "*"

@property
Expand Down Expand Up @@ -339,8 +365,9 @@ def expand(self, *, system_info: FullSystemInfo) -> Iterable[VMToken]:
This is used as part of :py:meth:`Policy.collect_targets_for_ask()`.
"""
if self in system_info["domains"]:
yield IntendedTarget(self)
domain = get_domain_name(system_info["domains"], self)
if domain is not None:
yield IntendedTarget(type(self)(domain))


class Target(_BaseTarget):
Expand All @@ -362,10 +389,43 @@ def __new__(
return super().__new__(cls, value, filepath=filepath, lineno=lineno) # type: ignore


_uuid_regex = re.compile(r"\A[0-9a-f]{8}(?:-[0-9a-f]{4}){3}-[0-9a-f]{12}\Z")

# FIXME: this is O(n) for UUIDs
def get_domain(info: SystemInfo, domain: str) -> Optional[SystemInfoEntry]:
assert isinstance(domain, str)
if domain.startswith("@uuid:"):
uuid = domain[6:]
if not _uuid_regex.match(uuid):
return None
for domain_info in info.values():
if domain_info["uuid"] == uuid:
return domain_info
return None
try:
return info[domain]
except KeyError:
return None

# FIXME: this is O(n) for UUIDs
def get_domain_name(info: SystemInfo, domain: str) -> Optional[str]:
assert isinstance(domain, str)
if domain.startswith("@uuid:"):
uuid = domain[6:]
if not _uuid_regex.match(uuid):
return None
for name, domain_info in info.items():
if domain_info["uuid"] == uuid:
return name
return None
if domain in info:
return domain
return None

# this method (with overloads in subclasses) was verify_target_value
class IntendedTarget(VMToken):
# pylint: disable=missing-docstring
def verify(self, *, system_info: FullSystemInfo) -> VMToken:
def verify(self, *, system_info: FullSystemInfo) -> Optional[VMToken]:
"""Check if given value names valid target
This function check if given value is not only syntactically correct,
Expand All @@ -387,7 +447,7 @@ def verify(self, *, system_info: FullSystemInfo) -> VMToken:
if type(self) != IntendedTarget:
raise NotImplementedError()

if self not in system_info["domains"]:
if get_domain(system_info["domains"], self) is None:
logging.warning(
"qrexec: target %r does not exist, using @default instead",
str(self),
Expand All @@ -410,7 +470,7 @@ class WildcardVM(Source, Target):

def match(
self,
other: Optional[str],
other: str,
*,
system_info: FullSystemInfo,
source: Optional[VMToken]=None
Expand Down Expand Up @@ -443,7 +503,7 @@ class AnyVM(Source, Target):

def match(
self,
other: Optional[str],
other: str,
*,
system_info: FullSystemInfo,
source: Optional[VMToken]=None
Expand Down Expand Up @@ -476,15 +536,15 @@ class TypeVM(Source, Target):

def match(
self,
other: Optional[str],
other: str,
*,
system_info: FullSystemInfo,
source: Optional[VMToken]=None
) -> bool:
_system_info = system_info["domains"]
other_vm = get_domain(system_info["domains"], other)
return (
other in _system_info
and self.value == _system_info[other]["type"]
other_vm is not None
and self.value == other_vm["type"]
)

def expand(self, *, system_info: FullSystemInfo) -> Iterable[IntendedTarget]:
Expand All @@ -499,15 +559,15 @@ class TagVM(Source, Target):

def match(
self,
other: Optional[str],
other: str,
*,
system_info: FullSystemInfo,
source: Optional[VMToken]=None
) -> bool:
_system_info = system_info["domains"]
other_vm = get_domain(system_info["domains"], other)
return (
other in _system_info
and self.value in _system_info[other]["tags"]
other_vm is not None
and self.value in other_vm["tags"]
)

def expand(self, *, system_info: FullSystemInfo) -> Iterable[IntendedTarget]:
Expand All @@ -522,7 +582,7 @@ class DispVM(Target, Redirect, IntendedTarget):

def match(
self,
other: Optional[str],
other: str,
*,
system_info: FullSystemInfo,
source: Optional[VMToken]=None
Expand All @@ -542,10 +602,10 @@ def get_dispvm_template(
system_info: FullSystemInfo,
) -> Optional["DispVMTemplate"]:
"""Given source, get appropriate template for DispVM. Maybe None."""
_system_info = system_info["domains"]
if source not in _system_info:
source_info = get_domain(system_info["domains"], source)
if source_info is None:
return None
template = _system_info[source].get("default_dispvm", None)
template = source_info.get("default_dispvm", None)
if template is None:
return None
return DispVMTemplate("@dispvm:" + template)
Expand All @@ -556,7 +616,7 @@ class DispVMTemplate(Source, Target, Redirect, IntendedTarget):

def match(
self,
other: Optional[str],
other: str,
*,
system_info: FullSystemInfo,
source: Optional[VMToken]=None
Expand All @@ -568,16 +628,17 @@ def match(
return self == other

def expand(self, *, system_info: FullSystemInfo) -> Iterable["DispVMTemplate"]:
if system_info["domains"][self.value]["template_for_dispvms"]:
domain = get_domain(system_info["domains"], self.value)
assert domain is not None
if domain["template_for_dispvms"]:
yield self
# else: log a warning?

def verify(self, *, system_info: FullSystemInfo) -> "DispVMTemplate":
_system_info = system_info["domains"]
if (
self.value not in _system_info
or not _system_info[self.value]["template_for_dispvms"]
):
def verify(self, *, system_info: FullSystemInfo) -> Optional["DispVMTemplate"]:
self_info = get_domain(system_info["domains"], self.value)
if self_info is None:
return None
if not self_info["template_for_dispvms"]:
raise AccessDenied(
"not a template for dispvm: {}".format(self.value)
)
Expand All @@ -590,22 +651,21 @@ class DispVMTag(Source, Target):

def match(
self,
other: Optional[str],
other: str,
*,
system_info: FullSystemInfo,
source: Optional[VMToken]=None
) -> bool:
if isinstance(other, DispVM):
assert source is not None
other = other.get_dispvm_template(source, system_info=system_info)
if isinstance(other, DispVM) and source is not None:
return self == other.get_dispvm_template(source, system_info=system_info)

if not isinstance(other, DispVMTemplate):
# 1) original other may have been neither @dispvm:<name> nor @dispvm
# 2) other.get_dispvm_template() may have been None
return False

domain = system_info["domains"][other.value]
if not domain["template_for_dispvms"]:
domain = get_domain(system_info["domains"], other.value)
if domain is None or not domain["template_for_dispvms"]:
return False
if not self.value in domain["tags"]:
return False
Expand Down
Loading

0 comments on commit 75ed833

Please sign in to comment.