Skip to content

Commit

Permalink
Fix/update unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
TheRealFalcon committed Jul 18, 2023
1 parent 16ce756 commit 9e3ea3e
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 232 deletions.
88 changes: 25 additions & 63 deletions cloudinit/distros/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)

import cloudinit.net.netops.iproute2 as iproute2
from cloudinit import importer
from cloudinit import helpers, importer
from cloudinit import log as logging
from cloudinit import (
net,
Expand Down Expand Up @@ -164,9 +164,7 @@ def _extract_package_by_manager(
if isinstance(entry, dict):
for key, value in entry.items():
try:
packages_by_manager[known_package_managers[key]].add(
value
)
packages_by_manager[known_package_managers[key]].add(value)
except KeyError:
LOG.error(
"Cannot install packages under %s as it is "
Expand All @@ -177,8 +175,7 @@ def _extract_package_by_manager(
generic_packages.add(entry)
else:
raise ValueError(
"Invalid 'packages' yaml specification. "
"Check schema definition."
"Invalid 'packages' yaml specification. " "Check schema definition."
)
return dict(packages_by_manager), generic_packages

Expand All @@ -199,13 +196,10 @@ def install_packages(self, pkglist):
uninstalled = []
for manager in self.package_managers:
to_try = (
packages_by_manager.get(manager.__class__, set())
| generic_packages
packages_by_manager.get(manager.__class__, set()) | generic_packages
)
uninstalled = manager.install_packages(to_try)
failed = {
pkg for pkg in uninstalled if pkg not in generic_packages
}
failed = {pkg for pkg in uninstalled if pkg not in generic_packages}
if failed:
LOG.error(error_message, failed)
generic_packages = set(uninstalled)
Expand All @@ -220,9 +214,9 @@ def install_packages(self, pkglist):
manager_name,
)
uninstalled.extend(
manager_name.from_config(
self._runner, self._cfg
).install_packages(pkglist=packages)
manager_name.from_config(self._runner, self._cfg).install_packages(
pkglist=packages
)
)

if uninstalled:
Expand All @@ -238,23 +232,17 @@ def _write_network(self, settings):
@property
def network_activator(self) -> Optional[Type[activators.NetworkActivator]]:
"""Return the configured network activator for this environment."""
priority = util.get_cfg_by_path(
self._cfg, ("network", "activators"), None
)
priority = util.get_cfg_by_path(self._cfg, ("network", "activators"), None)
try:
return activators.select_activator(priority=priority)
except activators.NoActivatorException:
return None

def _get_renderer(self) -> Renderer:
priority = util.get_cfg_by_path(
self._cfg, ("network", "renderers"), None
)
priority = util.get_cfg_by_path(self._cfg, ("network", "renderers"), None)

name, render_cls = renderers.select(priority=priority)
LOG.debug(
"Selected renderer '%s' from priority list: %s", name, priority
)
LOG.debug("Selected renderer '%s' from priority list: %s", name, priority)
renderer = render_cls(config=self.renderer_configs.get(name))
return renderer

Expand All @@ -264,9 +252,7 @@ def _write_network_state(self, network_state, renderer: Renderer):
def _find_tz_file(self, tz):
tz_file = os.path.join(self.tz_zone_dir, str(tz))
if not os.path.isfile(tz_file):
raise IOError(
"Invalid timezone %s, no file found at %s" % (tz, tz_file)
)
raise IOError("Invalid timezone %s, no file found at %s" % (tz, tz_file))
return tz_file

def get_option(self, opt_name, default=None):
Expand Down Expand Up @@ -297,9 +283,7 @@ def update_package_sources(self):
try:
manager.update_package_sources()
except Exception as e:
LOG.error(
"Failed to update package using %s: %s", manager.name, e
)
LOG.error("Failed to update package using %s: %s", manager.name, e)

def get_primary_arch(self):
arch = os.uname()[4]
Expand All @@ -317,9 +301,7 @@ def get_package_mirror_info(self, arch=None, data_source=None):
# This resolves the package_mirrors config option
# down to a single dict of {mirror_name: mirror_url}
arch_info = self._get_arch_package_mirror_info(arch)
return _get_package_mirror_info(
data_source=data_source, mirror_info=arch_info
)
return _get_package_mirror_info(data_source=data_source, mirror_info=arch_info)

def apply_network(self, settings, bring_up=True):
"""Deprecated. Remove if/when arch and gentoo support renderers."""
Expand Down Expand Up @@ -351,9 +333,7 @@ def _apply_network_from_network_config(self, netconfig, bring_up=True):
]
)
ns = network_state.parse_net_config_data(netconfig)
contents = eni.network_state_to_eni(
ns, header=header, render_hwaddress=True
)
contents = eni.network_state_to_eni(ns, header=header, render_hwaddress=True)
return self.apply_network(contents, bring_up=bring_up)

def generate_fallback_config(self):
Expand All @@ -374,9 +354,7 @@ def apply_network_config(self, netconfig, bring_up=False) -> bool:
renderer = self._get_renderer()
except NotImplementedError:
# backwards compat until all distros have apply_network_config
return self._apply_network_from_network_config(
netconfig, bring_up=bring_up
)
return self._apply_network_from_network_config(netconfig, bring_up=bring_up)

network_state = parse_net_config_data(netconfig, renderer=renderer)
self._write_network_state(network_state, renderer)
Expand All @@ -387,8 +365,7 @@ def apply_network_config(self, netconfig, bring_up=False) -> bool:
network_activator = self.network_activator
if not network_activator:
LOG.warning(
"No network activator found, not bringing up "
"network interfaces"
"No network activator found, not bringing up " "network interfaces"
)
return True
network_activator.bring_up_all_interfaces(network_state)
Expand Down Expand Up @@ -427,9 +404,7 @@ def _apply_hostname(self, hostname):
# temporarily (until reboot so it should
# not be depended on). Use the write
# hostname functions for 'permanent' adjustments.
LOG.debug(
"Non-persistently setting the system hostname to %s", hostname
)
LOG.debug("Non-persistently setting the system hostname to %s", hostname)
try:
subp.subp(["hostname", hostname])
except subp.ProcessExecutionError:
Expand Down Expand Up @@ -520,9 +495,7 @@ def update_hostname(self, hostname, fqdn, prev_hostname_fn):
try:
self._write_hostname(hostname, fn)
except IOError:
util.logexc(
LOG, "Failed to write hostname %s to %s", hostname, fn
)
util.logexc(LOG, "Failed to write hostname %s to %s", hostname, fn)

# If the system hostname file name was provided set the
# non-fqdn as the transient hostname.
Expand Down Expand Up @@ -826,9 +799,7 @@ def create_user(self, name, **kwargs):
disable_option = ssh_util.DISABLE_USER_OPTS
disable_option = disable_option.replace("$USER", redirect_user)
disable_option = disable_option.replace("$DISABLE_USER", name)
ssh_util.setup_user_keys(
set(cloud_keys), name, options=disable_option
)
ssh_util.setup_user_keys(set(cloud_keys), name, options=disable_option)
return True

def lock_passwd(self, name):
Expand Down Expand Up @@ -877,9 +848,7 @@ def set_passwd(self, user, passwd, hashed=False):

def chpasswd(self, plist_in: list, hashed: bool):
payload = (
"\n".join(
(":".join([name, password]) for name, password in plist_in)
)
"\n".join((":".join([name, password]) for name, password in plist_in))
+ "\n"
)
cmd = ["chpasswd"] + (["-e"] if hashed else [])
Expand Down Expand Up @@ -1031,9 +1000,7 @@ def reload_init(cls, rcs=None):
return subp.subp(cmd, capture=True, rcs=rcs)

@classmethod
def manage_service(
cls, action: str, service: str, *extra_args: str, rcs=None
):
def manage_service(cls, action: str, service: str, *extra_args: str, rcs=None):
"""
Perform the requested action on a service. This handles the common
'systemctl' and 'service' cases and may be overridden in subclasses
Expand Down Expand Up @@ -1220,13 +1187,9 @@ def _sanitize_mirror_url(url: str):
# decode with ASCII so we return a `str`
lambda hostname: hostname.encode("idna").decode("ascii"),
# Replace any unacceptable characters with "-"
lambda hostname: "".join(
c if c in acceptable_chars else "-" for c in hostname
),
lambda hostname: "".join(c if c in acceptable_chars else "-" for c in hostname),
# Drop leading/trailing hyphens from each part of the hostname
lambda hostname: ".".join(
part.strip("-") for part in hostname.split(".")
),
lambda hostname: ".".join(part.strip("-") for part in hostname.split(".")),
]

return _apply_hostname_transformations_to_url(url, transformations)
Expand Down Expand Up @@ -1299,8 +1262,7 @@ def fetch(name: str) -> Type[Distro]:
locs, looked_locs = importer.find_module(name, ["", __name__], ["Distro"])
if not locs:
raise ImportError(
"No distribution found for distro %s (searched %s)"
% (name, looked_locs)
"No distribution found for distro %s (searched %s)" % (name, looked_locs)
)
mod = importer.import_module(locs[0])
cls = getattr(mod, "Distro")
Expand Down
19 changes: 5 additions & 14 deletions cloudinit/distros/package_management/apt.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def get_apt_wrapper(cfg: Optional[dict]) -> List[str]:
if not cfg:
enabled = "auto"
command = ["eatmydata"]
# return bool(subp.which("eatmydata")), ["eatmydata"]
else:
enabled = cfg.get("enabled")
command = cfg.get("command")
Expand All @@ -61,9 +60,7 @@ def get_apt_wrapper(cfg: Optional[dict]) -> List[str]:
raise TypeError("apt_wrapper command must be a string or list")

if util.is_true(enabled) or (
str(enabled).lower() == "auto"
and command
and subp.which(command[0])
str(enabled).lower() == "auto" and command and subp.which(command[0])
):
return command
else:
Expand All @@ -86,21 +83,17 @@ def __init__(
apt_get_command = APT_GET_COMMAND
if apt_get_upgrade_subcommand is None:
apt_get_upgrade_subcommand = "dist-upgrade"
self.apt_command = tuple(apt_get_wrapper_command) + tuple(
apt_get_command
)
self.apt_command = tuple(apt_get_wrapper_command) + tuple(apt_get_command)

self.apt_get_upgrade_subcommand = apt_get_upgrade_subcommand
self.environment = os.environ.copy()
self.environment["DEBIAN_FRONTEND"] = "noninteracitve"
self.environment["DEBIAN_FRONTEND"] = "noninteractive"

@classmethod
def from_config(cls, runner: helpers.Runners, cfg: Mapping) -> "Apt":
return Apt(
runner,
apt_get_wrapper_command=get_apt_wrapper(
cfg.get("apt_get_wrapper")
),
apt_get_wrapper_command=get_apt_wrapper(cfg.get("apt_get_wrapper")),
apt_get_command=cfg.get("apt_get_command"),
apt_get_upgrade_subcommand=cfg.get("apt_get_upgrade_subcommand"),
)
Expand All @@ -126,9 +119,7 @@ def get_all_packages(self):
def get_unavailable_packages(self, pkglist: Iterable[str]):
return [pkg for pkg in pkglist if pkg not in self.get_all_packages()]

def install_packages(
self, pkglist: Iterable[str]
) -> UninstalledPackages:
def install_packages(self, pkglist: Iterable[str]) -> UninstalledPackages:
self.update_package_sources()
unavailable = self.get_unavailable_packages(pkglist)
to_install = [p for p in pkglist if p not in unavailable]
Expand Down
7 changes: 5 additions & 2 deletions cloudinit/distros/ubuntu.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,11 @@ def __init__(self, name, cfg, paths):
self.package_managers.append(self.snap)

def package_command(self, command, args=None, pkgs=None):
super().package_command(command, args, pkgs)
self.snap.upgrade_packages()
if command == 'upgrade':
super().package_command(command, args, pkgs)
self.snap.upgrade_packages()
else:
raise RuntimeError(f"Unable to handle {command} command")

@property
def preferred_ntp_clients(self):
Expand Down
Loading

0 comments on commit 9e3ea3e

Please sign in to comment.