Skip to content

Commit

Permalink
refactor omegaconf
Browse files Browse the repository at this point in the history
  • Loading branch information
MatteoVoges committed Feb 14, 2024
1 parent f7ed0a4 commit f9eaff9
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 162 deletions.
219 changes: 72 additions & 147 deletions kapitan/inventory/inv_omegaconf/inv_omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,127 +9,41 @@
import multiprocessing as mp
import os
from copy import deepcopy
from dataclasses import dataclass, field
from time import time

import yaml
from omegaconf import ListMergeMode, OmegaConf
from omegaconf import ListMergeMode, OmegaConf, DictConfig

from kapitan import cached
from .migrate import migrate
from ..inventory import InventoryError, Inventory
from ..inventory import InventoryError, Inventory, InventoryTarget
from .resolvers import register_resolvers

logger = logging.getLogger(__name__)


class InventoryTarget:
    """A single inventory target together with the parameters collected for it.

    ``targets_path`` and ``logfile`` are only declared on the class; ``__init__``
    reads ``self.targets_path``, so it has to be assigned before an instance is
    created.
    """

    targets_path: str
    logfile: str

    def __init__(self, target_name: str, target_path: str) -> None:
        """Store name and path and derive the dotted ``composed_name``.

        Args:
            target_name: name of the target.
            target_path: path of the target file.
        """
        self.path = target_path
        self.name = target_name

        # compose node name: drop the file extension and the targets directory
        # prefix, then turn the remaining path separators into dots
        relative_stem = os.path.splitext(target_path)[0].replace(self.targets_path + os.sep, "")
        # The prefix above is stripped using os.sep, so the separators have to be
        # converted using os.sep as well; "/" is still replaced to keep the
        # previous result for paths written with forward slashes.
        self.composed_name = relative_stem.replace(os.sep, ".").replace("/", ".")

        self.classes: list = []
        self.parameters: dict = {}
        self.classes_redundancy_check: set = set()

    def _merge(self, class_parameters):
        """Combine ``class_parameters`` with the parameters collected so far.

        If nothing has been collected yet, ``class_parameters`` is taken as is.
        Otherwise both are handed to ``OmegaConf.unsafe_merge`` (class parameters
        first, current parameters second) with ``ListMergeMode.EXTEND``.
        """
        if not self.parameters:
            self.parameters = class_parameters
            return

        self.parameters = OmegaConf.unsafe_merge(
            class_parameters,
            self.parameters,
            list_merge_mode=ListMergeMode.EXTEND,
        )

    def _resolve(self):
        """Resolve the parameters in two passes and convert them to a container."""
        escape_interpolation_strings = False
        OmegaConf.resolve(self.parameters, escape_interpolation_strings)

        # keys listed under "omegaconf.remove" are overwritten with an empty dict
        # between the two resolve passes
        remove_location = "omegaconf.remove"
        for key in OmegaConf.select(self.parameters, remove_location, default=[]):
            OmegaConf.update(self.parameters, key, {}, merge=False)

        # resolve second time and convert to object
        # add throw_on_missing = True when resolving second time (--> wait for to_object support)
        # reference: https://github.com/omry/omegaconf/pull/1113
        OmegaConf.resolve(self.parameters, escape_interpolation_strings)
        self.parameters = OmegaConf.to_container(self.parameters)

    def add_metadata(self):
        """Store name metadata under ``_meta_`` and the legacy ``_reclass_`` key."""
        # append meta data (legacy: _reclass_)
        _meta_ = {
            "name": {
                "full": self.name,
                "parts": self.name.split("."),
                "path": self.name.replace(".", "/"),
                "short": self.name,
            }
        }
        self.parameters["_meta_"] = _meta_
        # independent copy: a change made under one key must not leak into the other
        self.parameters["_reclass_"] = deepcopy(_meta_)  # legacy


class InventoryClass:
    """A single inventory class file, identified by its path and dotted name."""

    # directory prefix that is stripped from a class path to build its name
    classes_path: str = "./inventory/classes"

    def __init__(self, class_path: str) -> None:
        """Store the path and derive the dotted class name from it.

        Args:
            class_path: path of the class file.
        """
        self.path = class_path

        # drop the extension and the classes directory prefix, then dot the rest
        relative_stem = os.path.splitext(class_path)[0].replace(self.classes_path + os.sep, "")
        # The prefix is stripped using os.sep, so separators are converted using
        # os.sep too; "/" is still replaced to keep the previous result for
        # paths written with forward slashes.
        self.name = relative_stem.replace(os.sep, ".").replace("/", ".")

        self.parameters = {}
        self.dependents = []


class OmegaConfInventory(Inventory):
classes_cache: dict = {}

# InventoryTarget.targets_path = self.targets_searchpath
# InventoryClass.classes_path = self.classes_searchpath
def render_targets(self, targets: list[InventoryTarget] = None, ignore_class_not_found: bool = False) -> None:

def inventory(self):
targets = targets or self.targets.values()
register_resolvers(self.inventory_path)
selected_targets = self.get_selected_targets()

# FEAT: add flag for multiprocessing
use_mp = True

if not use_mp:
nodes = {}
# load targets one by one
for target in selected_targets:
try:
self.load_target(target)
nodes[target.name] = {"parameters": target.parameters}
except Exception as e:
raise InventoryError(f"{target.name}: {e}")
else:
# load targets parallel
manager = mp.Manager() # perf: bottleneck --> 90 % of the inventory time

nodes = manager.dict()
mp.set_start_method("spawn", True) # platform independent
with mp.Pool(len(selected_targets)) as pool:
r = pool.map_async(
self.inventory_worker, [(self, target, nodes) for target in selected_targets]
)
r.wait()
# load targets parallel
manager = mp.Manager() # perf: bottleneck --> 90 % of the inventory time
shared_targets = manager.dict()

# using nodes for reclass legacy code
nodes = dict(nodes)
mp.set_start_method("spawn", True) # platform independent
with mp.Pool(min(len(targets), os.cpu_count())) as pool:
r = pool.map_async(self.inventory_worker, [(self, target, shared_targets) for target in targets])
r.wait()

# using nodes for reclass legacy code
return {"nodes": nodes}
# store parameters and classes
for target_name, rendered_target in rendered_inventory["nodes"].items():
self.targets[target_name].parameters = rendered_target["parameters"]
self.targets[target_name].classes = rendered_target["classes"]

@staticmethod
def inventory_worker(zipped_args):
Expand All @@ -149,45 +63,17 @@ def inventory_worker(zipped_args):
def migrate(self):
    """Run the module-level ``migrate`` helper on ``self.inventory_path``."""
    migrate(self.inventory_path)

# ----------
# private
# ----------
def get_selected_targets(self):
    """Walk ``self.targets_searchpath`` and build an ``InventoryTarget`` per target file.

    Files without a ``.yml``/``.yaml`` extension are skipped. When
    ``self.targets`` is non-empty, only files whose base name is contained in it
    are kept. With ``self.compose_node_name`` set, the target name is replaced
    by its composed name.

    Returns:
        list of the created ``InventoryTarget`` objects.
    """
    valid_extensions = (".yml", ".yaml")
    found = []

    for directory, _subdirs, filenames in os.walk(self.targets_searchpath):
        for filename in filenames:
            stem, extension = os.path.splitext(filename)
            if extension not in valid_extensions:
                logger.debug(f"{filename}: targets have to be .yml or .yaml files.")
                continue

            # an explicit selection (-t flag) restricts which targets are loaded
            if self.targets and stem not in self.targets:
                continue

            candidate = InventoryTarget(stem, os.path.join(directory, filename))
            if self.compose_node_name:
                candidate.name = candidate.composed_name
            found.append(candidate)

    return found

def load_target(self, target: InventoryTarget):
def _load_target(self, target: InventoryTarget):
"""
load only one target with all its classes
"""

# load the target parameters
target.classes, target.parameters = self.load_config(target.path)
target.classes, target.parameters = self._load_file(target.path)

# load classes for targets
for class_name in target.classes:
inv_class = self.load_class(target, class_name)
inv_class = self._load_class(target, class_name)
if not inv_class:
# either redundantly defined or not found (with ignore_not_found: true)
continue
Expand All @@ -210,13 +96,14 @@ def load_target(self, target: InventoryTarget):
# add hint to kapitan.vars.target
logger.warning(f"Could not resolve target name on target {target.name}")

def load_class(self, target: InventoryTarget, class_name: str):
def _load_class(self, target: InventoryTarget, class_name: str, ignore_class_not_found: bool = False):
# resolve class path (has to be absolute)
class_path = os.path.join(self.classes_searchpath, *class_name.split("."))
if class_path in target.classes_redundancy_check:
class_path = os.path.join(self.classes_path, *class_name.split("."))
if class_name in target.classes:
logger.debug(f"{class_path}: class {class_name} is redundantly defined")
return None
target.classes_redundancy_check.add(class_path)

target.classes.append(class_name)

# search in inventory classes cache, otherwise load class
if class_name in self.classes_cache.keys():
Expand All @@ -227,24 +114,21 @@ def load_class(self, target: InventoryTarget, class_name: str):
class_path += ".yml"
elif os.path.isdir(class_path):
# search for init file
init_path = os.path.join(self.classes_searchpath, *class_name.split("."), "init") + ".yml"
init_path = os.path.join(self.classes_path, *class_name.split("."), "init") + ".yml"
if os.path.isfile(init_path):
class_path = init_path
elif self.ignore_class_notfound:
elif ignore_class_notfound:
logger.debug(f"Could not find {class_path}")
return None
else:
raise InventoryError(f"Class {class_name} not found.")

# load classes recursively
classes, parameters = self.load_config(class_path)
classes, parameters = self._load_file(class_path)

if not classes and not parameters:
return None

# initialize inventory class
inv_class = InventoryClass(class_path)
inv_class.parameters = parameters
# resolve relative class names for new classes
for c in classes:
if c.startswith("."):
Expand All @@ -256,19 +140,60 @@ def load_class(self, target: InventoryTarget, class_name: str):

return inv_class

def load_config(self, path: str):
@staticmethod
def _load_file(path: str):
with open(path, "r") as f:
f.seek(0)
config = yaml.load(f, yaml.SafeLoader)

if not config:
logger.debug(f"{path}: file is empty")
return [], {}

classes = OmegaConf.create(config.get("classes", []))
parameters = OmegaConf.create(config.get("parameters", {}))

# add metadata to nodes
# add metadata (filename, filepath) to node
filename = os.path.splitext(os.path.split(path)[1])[0]
parameters._set_flag(["filename", "path"], [filename, path], recursive=True)

return classes, parameters
return classes, parameters

@staticmethod
def _merge_parameters(target_parameters: DictConfig, class_parameters: DictConfig) -> DictConfig:
    """Combine the parameters of a class with those already held by a target.

    When ``target_parameters`` is empty, ``class_parameters`` is returned
    untouched. Otherwise both are passed to ``OmegaConf.unsafe_merge`` (class
    parameters first, target parameters second) with ``ListMergeMode.EXTEND``.
    """
    if target_parameters:
        merged = OmegaConf.unsafe_merge(
            class_parameters,
            target_parameters,
            list_merge_mode=ListMergeMode.EXTEND,
        )
    else:
        merged = class_parameters
    return merged

@staticmethod
def _resolve_parameters(target_parameters: DictConfig):
    """Resolve ``target_parameters`` in two passes and return a plain container.

    Between the passes, every key listed under ``omegaconf.remove`` is
    overwritten with an empty dict.
    """
    # first pass
    OmegaConf.resolve(target_parameters, escape_interpolation_strings=False)

    # overwrite the keys named under "omegaconf.remove" before resolving again
    keys_to_clear = OmegaConf.select(target_parameters, "omegaconf.remove", default=[])
    for dotted_key in keys_to_clear:
        OmegaConf.update(target_parameters, dotted_key, {}, merge=False)

    # second pass, then convert to a container
    # TODO: add `throw_on_missing = True` when resolving second time (--> wait for to_object support)
    # reference: https://github.com/omry/omegaconf/pull/1113
    OmegaConf.resolve(target_parameters, escape_interpolation_strings=False)
    container = OmegaConf.to_container(target_parameters)
    return container

@staticmethod
def _add_metadata(target: InventoryTarget):
# append meta data (legacy: _reclass_)
_kapitan_ = {
"name": {
"full": target.name,
"parts": target.name.split("."),
"path": target.name.replace(".", "/"),
"short": target.name,
}
}
target.parameters["_kapitan_"] = _kapitan_
target.parameters["_reclass_"] = _kapitan_ # legacy
4 changes: 2 additions & 2 deletions kapitan/inventory/inv_reclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@

from kapitan.errors import InventoryError

from .inventory import Inventory
from .inventory import Inventory, InventoryTarget

logger = logging.getLogger(__name__)


class ReclassInventory(Inventory):

def render_targets(self, targets: list = None, ignore_class_notfound: bool = False):
def render_targets(self, targets: list[InventoryTarget] = None, ignore_class_notfound: bool = False) -> None:
"""
Runs a reclass inventory in inventory_path
(same output as running ./reclass.py -b inv_base_uri/ --inventory)
Expand Down
20 changes: 7 additions & 13 deletions kapitan/inventory/inventory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
class InventoryTarget:
name: str
path: str
composed_name: str
parameters: dict = field(default_factory=dict)
classes: list = field(default_factory=list)

Expand Down Expand Up @@ -63,21 +62,16 @@ def search_targets(self) -> dict:
for root, dirs, files in os.walk(self.targets_path):
for file in files:
# split file extension and check if yml/yaml
path = os.path.join(root, file)
path = os.path.relpath(os.path.join(root, file), self.targets_path)
name, ext = os.path.splitext(file)
if ext not in (".yml", ".yaml"):
logger.debug(f"{file}: targets have to be .yml or .yaml files.")
logger.error(f"{file}: targets have to be .yml or .yaml files.")
continue

# initialize target
composed_name = (
os.path.splitext(os.path.relpath(path, self.targets_path))[0]
.replace(os.sep, ".")
.lstrip(".")
)
target = InventoryTarget(name, path, composed_name)
if self.compose_target_name:
target.name = target.composed_name
name = path.replace(os.sep, ".")
target = InventoryTarget(name, path)

# check for same name
if self.targets.get(target.name):
Expand All @@ -95,7 +89,7 @@ def get_target(self, target_name: str, ignore_class_not_found: bool = False) ->
"""
return self.get_targets([target_name], ignore_class_not_found)[target_name]

def get_targets(self, target_names: list, ignore_class_not_found: bool = False) -> dict:
def get_targets(self, target_names: list[str], ignore_class_not_found: bool = False) -> dict:
"""
helper function to get rendered InventoryTarget objects for multiple targets
"""
Expand All @@ -115,7 +109,7 @@ def get_targets(self, target_names: list, ignore_class_not_found: bool = False)

return {name: target for name, target in self.targets.items() if name in target_names}

def get_parameters(self, target_names: Union[str, list], ignore_class_not_found: bool = False) -> dict:
def get_parameters(self, target_names: str | list[str], ignore_class_not_found: bool = False) -> dict:
"""
helper function to get rendered parameters for single target or multiple targets
"""
Expand All @@ -126,7 +120,7 @@ def get_parameters(self, target_names: Union[str, list], ignore_class_not_found:
return {name: target.parameters for name, target in self.get_targets(target_names)}

@abstractmethod
def render_targets(self, targets: list = None, ignore_class_notfound: bool = False):
def render_targets(self, targets: list[InventoryTarget] = None, ignore_class_notfound: bool = False) -> None:
"""
create the inventory depending on which backend gets used
"""
Expand Down

0 comments on commit f9eaff9

Please sign in to comment.