From f9eaff96e2f9707428a6b067f2c9dda1ffb36c14 Mon Sep 17 00:00:00 2001 From: Matteo Voges Date: Wed, 14 Feb 2024 20:12:17 +0100 Subject: [PATCH] refactor omegaconf --- .../inventory/inv_omegaconf/inv_omegaconf.py | 219 ++++++------------ kapitan/inventory/inv_reclass.py | 4 +- kapitan/inventory/inventory.py | 20 +- 3 files changed, 81 insertions(+), 162 deletions(-) diff --git a/kapitan/inventory/inv_omegaconf/inv_omegaconf.py b/kapitan/inventory/inv_omegaconf/inv_omegaconf.py index 146ece4ae..db393d130 100644 --- a/kapitan/inventory/inv_omegaconf/inv_omegaconf.py +++ b/kapitan/inventory/inv_omegaconf/inv_omegaconf.py @@ -9,127 +9,41 @@ import multiprocessing as mp import os from copy import deepcopy +from dataclasses import dataclass, field from time import time import yaml -from omegaconf import ListMergeMode, OmegaConf +from omegaconf import ListMergeMode, OmegaConf, DictConfig from kapitan import cached from .migrate import migrate -from ..inventory import InventoryError, Inventory +from ..inventory import InventoryError, Inventory, InventoryTarget from .resolvers import register_resolvers logger = logging.getLogger(__name__) -class InventoryTarget: - targets_path: str - logfile: str - - def __init__(self, target_name: str, target_path: str) -> None: - self.path = target_path - self.name = target_name - - # compose node name - self.composed_name = ( - os.path.splitext(target_path)[0].replace(self.targets_path + os.sep, "").replace("/", ".") - ) - - self.classes: list = [] - self.parameters: dict = {} - self.classes_redundancy_check: set = set() - - def _merge(self, class_parameters): - if not self.parameters: - self.parameters = class_parameters - else: - merged_parameters = OmegaConf.unsafe_merge( - class_parameters, - self.parameters, - list_merge_mode=ListMergeMode.EXTEND, - ) - - self.parameters = merged_parameters - - def _resolve(self): - escape_interpolation_strings = False - OmegaConf.resolve(self.parameters, escape_interpolation_strings) - - # remove specified keys - remove_location = "omegaconf.remove" - removed_keys = OmegaConf.select(self.parameters, remove_location, default=[]) - for key in removed_keys: - OmegaConf.update(self.parameters, key, {}, merge=False) - - # resolve second time and convert to object - # add throw_on_missing = True when resolving second time (--> wait for to_object support) - # reference: https://github.com/omry/omegaconf/pull/1113 - OmegaConf.resolve(self.parameters, escape_interpolation_strings) - self.parameters = OmegaConf.to_container(self.parameters) - - def add_metadata(self): - # append meta data (legacy: _reclass_) - _meta_ = { - "name": { - "full": self.name, - "parts": self.name.split("."), - "path": self.name.replace(".", "/"), - "short": self.name, - } - } - self.parameters["_meta_"] = _meta_ - self.parameters["_reclass_"] = _meta_ # legacy - - -class InventoryClass: - classes_path: str = "./inventory/classes" - - def __init__(self, class_path: str) -> None: - self.path = class_path - self.name = os.path.splitext(class_path)[0].replace(self.classes_path + os.sep, "").replace("/", ".") - self.parameters = {} - self.dependents = [] - - class OmegaConfInventory(Inventory): classes_cache: dict = {} - # InventoryTarget.targets_path = self.targets_searchpath - # InventoryClass.classes_path = self.classes_searchpath + def render_targets(self, targets: list[InventoryTarget] = None, ignore_class_not_found: bool = False) -> None: - def inventory(self): + targets = targets or self.targets.values() register_resolvers(self.inventory_path) - selected_targets = self.get_selected_targets() - - # FEAT: add flag for multiprocessing - use_mp = True - - if not use_mp: - nodes = {} - # load targets one by one - for target in selected_targets: - try: - self.load_target(target) - nodes[target.name] = {"parameters": target.parameters} - except Exception as e: - raise InventoryError(f"{target.name}: {e}") - else: - # load targets parallel - manager = mp.Manager() # perf: bottleneck --> 90 % of the inventory time - nodes = manager.dict() - mp.set_start_method("spawn", True) # platform independent - with mp.Pool(len(selected_targets)) as pool: - r = pool.map_async( - self.inventory_worker, [(self, target, nodes) for target in selected_targets] - ) - r.wait() + # load targets parallel + manager = mp.Manager() # perf: bottleneck --> 90 % of the inventory time + shared_targets = manager.dict() - # using nodes for reclass legacy code - nodes = dict(nodes) + mp.set_start_method("spawn", True) # platform independent + with mp.Pool(min(len(targets), os.cpu_count())) as pool: + r = pool.map_async(self.inventory_worker, [(self, target, shared_targets) for target in targets]) + r.wait() - # using nodes for reclass legacy code - return {"nodes": nodes} + # store parameters and classes + for target_name, rendered_target in rendered_inventory["nodes"].items(): + self.targets[target_name].parameters = rendered_target["parameters"] + self.targets[target_name].classes = rendered_target["classes"] @staticmethod def inventory_worker(zipped_args): @@ -149,45 +63,17 @@ def inventory_worker(zipped_args): def migrate(self): migrate(self.inventory_path) - # ---------- - # private - # ---------- - def get_selected_targets(self): - selected_targets = [] - - # loop through targets searchpath and load all targets - for root, dirs, files in os.walk(self.targets_searchpath): - for target_file in files: - # split file extension and check if yml/yaml - target_path = os.path.join(root, target_file) - target_name, ext = os.path.splitext(target_file) - if ext not in (".yml", ".yaml"): - logger.debug(f"{target_file}: targets have to be .yml or .yaml files.") - continue - - # skip targets if they are not specified with -t flag - if self.targets and target_name not in self.targets: - continue - - # initialize target - target = InventoryTarget(target_name, target_path) - if self.compose_node_name: - target.name = target.composed_name - selected_targets.append(target) - - return selected_targets - - def load_target(self, target: InventoryTarget): + def _load_target(self, target: InventoryTarget): """ load only one target with all its classes """ # load the target parameters - target.classes, target.parameters = self.load_config(target.path) + target.classes, target.parameters = self._load_file(target.path) # load classes for targets for class_name in target.classes: - inv_class = self.load_class(target, class_name) + inv_class = self._load_class(target, class_name) if not inv_class: # either redundantly defined or not found (with ignore_not_found: true) continue @@ -210,13 +96,14 @@ def load_target(self, target: InventoryTarget): # add hint to kapitan.vars.target logger.warning(f"Could not resolve target name on target {target.name}") - def load_class(self, target: InventoryTarget, class_name: str): + def _load_class(self, target: InventoryTarget, class_name: str, ignore_class_not_found: bool = False): # resolve class path (has to be absolute) - class_path = os.path.join(self.classes_searchpath, *class_name.split(".")) - if class_path in target.classes_redundancy_check: + class_path = os.path.join(self.classes_path, *class_name.split(".")) + if class_name in target.classes: logger.debug(f"{class_path}: class {class_name} is redundantly defined") return None - target.classes_redundancy_check.add(class_path) + + target.classes.append(class_name) # search in inventory classes cache, otherwise load class if class_name in self.classes_cache.keys(): @@ -227,24 +114,21 @@ def load_class(self, target: InventoryTarget, class_name: str): class_path += ".yml" elif os.path.isdir(class_path): # search for init file - init_path = os.path.join(self.classes_searchpath, *class_name.split("."), "init") + ".yml" + init_path = os.path.join(self.classes_path, *class_name.split("."), "init") + ".yml" if os.path.isfile(init_path): class_path = init_path - elif self.ignore_class_notfound: + elif ignore_class_notfound: logger.debug(f"Could not find {class_path}") return None else: raise InventoryError(f"Class {class_name} not found.") # load classes recursively - classes, parameters = self.load_config(class_path) + classes, parameters = self._load_file(class_path) if not classes and not parameters: return None - # initialize inventory class - inv_class = InventoryClass(class_path) - inv_class.parameters = parameters # resolve relative class names for new classes for c in classes: if c.startswith("."): @@ -256,19 +140,60 @@ def load_class(self, target: InventoryTarget, class_name: str): return inv_class - def load_config(self, path: str): + @staticmethod + def _load_file(path: str): with open(path, "r") as f: f.seek(0) config = yaml.load(f, yaml.SafeLoader) if not config: - logger.debug(f"{path}: file is empty") return [], {} + classes = OmegaConf.create(config.get("classes", [])) parameters = OmegaConf.create(config.get("parameters", {})) - # add metadata to nodes + # add metadata (filename, filepath) to node filename = os.path.splitext(os.path.split(path)[1])[0] parameters._set_flag(["filename", "path"], [filename, path], recursive=True) - return classes, parameters \ No newline at end of file + return classes, parameters + + @staticmethod + def _merge_parameters(target_parameters: DictConfig, class_parameters: DictConfig) -> DictConfig: + if not target_parameters: + return class_parameters + + return OmegaConf.unsafe_merge( + class_parameters, target_parameters, list_merge_mode=ListMergeMode.EXTEND, + ) + + @staticmethod + def _resolve_parameters(target_parameters: DictConfig): + # resolve first time + OmegaConf.resolve(target_parameters, escape_interpolation_strings=False) + + # remove specified keys between first and second resolve-stage + remove_location = "omegaconf.remove" + removed_keys = OmegaConf.select(target_parameters, remove_location, default=[]) + for key in removed_keys: + OmegaConf.update(target_parameters, key, {}, merge=False) + + # resolve second time and convert to object + # TODO: add `throw_on_missing = True` when resolving second time (--> wait for to_object support) + # reference: https://github.com/omry/omegaconf/pull/1113 + OmegaConf.resolve(target_parameters, escape_interpolation_strings=False) + return OmegaConf.to_container(target_parameters) + + @staticmethod + def _add_metadata(target: InventoryTarget): + # append meta data (legacy: _reclass_) + _kapitan_ = { + "name": { + "full": target.name, + "parts": target.name.split("."), + "path": target.name.replace(".", "/"), + "short": target.name, + } + } + target.parameters["_kapitan_"] = _kapitan_ + target.parameters["_reclass_"] = _kapitan_ # legacy diff --git a/kapitan/inventory/inv_reclass.py b/kapitan/inventory/inv_reclass.py index d1fc73891..d6f5ca2b2 100644 --- a/kapitan/inventory/inv_reclass.py +++ b/kapitan/inventory/inv_reclass.py @@ -8,14 +8,14 @@ from kapitan.errors import InventoryError -from .inventory import Inventory +from .inventory import Inventory, InventoryTarget logger = logging.getLogger(__name__) class ReclassInventory(Inventory): - def render_targets(self, targets: list = None, ignore_class_notfound: bool = False): + def render_targets(self, targets: list[InventoryTarget] = None, ignore_class_notfound: bool = False) -> None: """ Runs a reclass inventory in inventory_path (same output as running ./reclass.py -b inv_base_uri/ --inventory) diff --git a/kapitan/inventory/inventory.py b/kapitan/inventory/inventory.py index 7cab807d5..453a1b955 100644 --- a/kapitan/inventory/inventory.py +++ b/kapitan/inventory/inventory.py @@ -22,7 +22,6 @@ class InventoryTarget: name: str path: str - composed_name: str parameters: dict = field(default_factory=dict) classes: list = field(default_factory=list) @@ -63,21 +62,16 @@ def search_targets(self) -> dict: for root, dirs, files in os.walk(self.targets_path): for file in files: # split file extension and check if yml/yaml - path = os.path.join(root, file) + path = os.path.relpath(os.path.join(root, file), self.targets_path) name, ext = os.path.splitext(file) if ext not in (".yml", ".yaml"): - logger.debug(f"{file}: targets have to be .yml or .yaml files.") + logger.error(f"{file}: targets have to be .yml or .yaml files.") continue # initialize target - composed_name = ( - os.path.splitext(os.path.relpath(path, self.targets_path))[0] - .replace(os.sep, ".") - .lstrip(".") - ) - target = InventoryTarget(name, path, composed_name) if self.compose_target_name: - target.name = target.composed_name + name = path.replace(os.sep, ".") + target = InventoryTarget(name, path) # check for same name if self.targets.get(target.name): @@ -95,7 +89,7 @@ def get_target(self, target_name: str, ignore_class_not_found: bool = False) -> """ return self.get_targets([target_name], ignore_class_not_found)[target_name] - def get_targets(self, target_names: list, ignore_class_not_found: bool = False) -> dict: + def get_targets(self, target_names: list[str], ignore_class_not_found: bool = False) -> dict: """ helper function to get rendered InventoryTarget objects for multiple targets """ @@ -115,7 +109,7 @@ def get_targets(self, target_names: list, ignore_class_not_found: bool = False) return {name: target for name, target in self.targets.items() if name in target_names} - def get_parameters(self, target_names: Union[str, list], ignore_class_not_found: bool = False) -> dict: + def get_parameters(self, target_names: str | list[str], ignore_class_not_found: bool = False) -> dict: """ helper function to get rendered parameters for single target or multiple targets """ @@ -126,7 +120,7 @@ def get_parameters(self, target_names: Union[str, list], ignore_class_not_found: return {name: target.parameters for name, target in self.get_targets(target_names)} @abstractmethod - def render_targets(self, targets: list = None, ignore_class_notfound: bool = False): + def render_targets(self, targets: list[InventoryTarget] = None, ignore_class_notfound: bool = False) -> None: """ create the inventory depending on which backend gets used """