Skip to content

Commit

Permalink
Improve OmegaConfigLoader performance (#4367)
Browse files Browse the repository at this point in the history
* improve global and runtime param resolver

* testing with slack suggestions

* fix type lint issue

* update release note

Signed-off-by: ravi_kumar_pilla <[email protected]>

---------

Signed-off-by: ravi_kumar_pilla <[email protected]>
  • Loading branch information
ravi-kumar-pilla authored Jan 3, 2025
1 parent 70734ce commit 057de34
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Major features and improvements
* Implemented `KedroDataCatalog.to_config()` method that converts the catalog instance into a configuration format suitable for serialization.
* Improve OmegaConfigLoader performance

## Bug fixes and other changes
* Added validation to ensure dataset versions consistency across catalog.
Expand Down
25 changes: 15 additions & 10 deletions kedro/config/omegaconf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def __init__( # noqa: PLR0913
self.base_env = base_env or ""
self.default_run_env = default_run_env or ""
self.merge_strategy = merge_strategy or {}

self._globals_oc: DictConfig | None = None
self._runtime_params_oc: DictConfig | None = None
self.config_patterns = {
"catalog": ["catalog*", "catalog*/**", "**/catalog*"],
"parameters": ["parameters*", "parameters*/**", "**/parameters*"],
Expand Down Expand Up @@ -346,12 +347,11 @@ def load_and_merge_dir_config(
OmegaConf.merge(*aggregate_config, self.runtime_params), resolve=True
)

merged_config_container = OmegaConf.to_container(
OmegaConf.merge(*aggregate_config), resolve=True
)
return {
k: v
for k, v in OmegaConf.to_container(
OmegaConf.merge(*aggregate_config), resolve=True
).items()
if not k.startswith("_")
k: v for k, v in merged_config_container.items() if not k.startswith("_")
}

@staticmethod
Expand Down Expand Up @@ -436,9 +436,12 @@ def _get_globals_value(self, variable: str, default_value: Any = _NO_VALUE) -> A
raise InterpolationResolutionError(
"Keys starting with '_' are not supported for globals."
)
globals_oc = OmegaConf.create(self._globals)

if not self._globals_oc:
self._globals_oc = OmegaConf.create(self._globals)

interpolated_value = OmegaConf.select(
globals_oc, variable, default=default_value
self._globals_oc, variable, default=default_value
)
if interpolated_value != _NO_VALUE:
return interpolated_value
Expand All @@ -449,9 +452,11 @@ def _get_globals_value(self, variable: str, default_value: Any = _NO_VALUE) -> A

def _get_runtime_value(self, variable: str, default_value: Any = _NO_VALUE) -> Any:
"""Return the runtime params values to the resolver"""
runtime_oc = OmegaConf.create(self.runtime_params)
if not self._runtime_params_oc:
self._runtime_params_oc = OmegaConf.create(self.runtime_params)

interpolated_value = OmegaConf.select(
runtime_oc, variable, default=default_value
self._runtime_params_oc, variable, default=default_value
)
if interpolated_value != _NO_VALUE:
return interpolated_value
Expand Down

0 comments on commit 057de34

Please sign in to comment.