[DataCatalog2.0]: Move pattern resolution logic - catalog cli #4130

Closed
187 changes: 118 additions & 69 deletions kedro/framework/cli/catalog.py
@@ -14,14 +14,16 @@
from kedro.framework.cli.utils import KedroCliError, env_option, split_string
from kedro.framework.project import pipelines, settings
from kedro.framework.session import KedroSession
from kedro.io.data_catalog import DataCatalog
from kedro.io.core import is_parameter

if TYPE_CHECKING:
from pathlib import Path

from kedro.framework.startup import ProjectMetadata
from kedro.io import AbstractDataset

NEW_CATALOG_ARG_HELP = """Use KedroDataCatalog instead of DataCatalog to run project."""


def _create_session(package_name: str, **kwargs: Any) -> KedroSession:
kwargs.setdefault("save_on_close", False)
@@ -49,8 +51,13 @@ def catalog() -> None:
"the project pipeline is run by default.",
callback=split_string,
)
@click.option(
"--new_catalog", "-n", "new_catalog", is_flag=True, help=NEW_CATALOG_ARG_HELP
)
@click.pass_obj
def list_datasets(metadata: ProjectMetadata, pipeline: str, env: str) -> None:
def list_datasets( # noqa: PLR0912
metadata: ProjectMetadata, pipeline: str, env: str, new_catalog: bool
) -> None:
"""Show datasets per type."""
title = "Datasets in '{}' pipeline"
not_mentioned = "Datasets not mentioned in pipeline"
@@ -61,8 +68,14 @@ def list_datasets(metadata: ProjectMetadata, pipeline: str, env: str) -> None:
context = session.load_context()

try:
data_catalog = context.catalog
datasets_meta = data_catalog._datasets
catalog_config_resolver = None
if new_catalog:
data_catalog = context.catalog_new
datasets_meta = data_catalog.datasets
catalog_config_resolver = context.catalog_config_resolver
else:
data_catalog = context.catalog
datasets_meta = data_catalog._datasets
catalog_ds = set(data_catalog.list())
except Exception as exc:
raise KedroCliError(
@@ -86,23 +99,33 @@ def list_datasets(metadata: ProjectMetadata, pipeline: str, env: str) -> None:
default_ds = pipeline_ds - catalog_ds
used_ds = catalog_ds - unused_ds

# resolve any factory datasets in the pipeline
factory_ds_by_type = defaultdict(list)
for ds_name in default_ds:
matched_pattern = data_catalog._match_pattern(
data_catalog._dataset_patterns, ds_name
) or data_catalog._match_pattern(data_catalog._default_pattern, ds_name)
if matched_pattern:
ds_config_copy = copy.deepcopy(
data_catalog._dataset_patterns.get(matched_pattern)
or data_catalog._default_pattern.get(matched_pattern)
or {}
)

ds_config = data_catalog._resolve_config(
ds_name, matched_pattern, ds_config_copy
)
factory_ds_by_type[ds_config["type"]].append(ds_name)
if new_catalog:
resolved_configs = catalog_config_resolver.resolve_dataset_patterns(
default_ds
)
for ds_name, ds_config in zip(default_ds, resolved_configs):
if catalog_config_resolver.match_pattern(ds_name):
factory_ds_by_type[ds_config.get("type", "DefaultDataset")].append(
ds_name
)
else:
# resolve any factory datasets in the pipeline
for ds_name in default_ds:
matched_pattern = data_catalog._match_pattern(
data_catalog._dataset_patterns, ds_name
) or data_catalog._match_pattern(data_catalog._default_pattern, ds_name)
if matched_pattern:
ds_config_copy = copy.deepcopy(
data_catalog._dataset_patterns.get(matched_pattern)
or data_catalog._default_pattern.get(matched_pattern)
or {}
)

ds_config = data_catalog._resolve_config(
ds_name, matched_pattern, ds_config_copy
)
factory_ds_by_type[ds_config["type"]].append(ds_name)

default_ds = default_ds - set(chain.from_iterable(factory_ds_by_type.values()))
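
Aside (not part of the diff): a minimal sketch of how the new-catalog branch above could be driven on its own — resolving a batch of factory dataset names and grouping them by resolved type. It assumes a Kedro project on disk, that `context.catalog_config_resolver` exists as introduced by this PR, and that `resolve_dataset_patterns` and `match_pattern` keep the signatures used above.

```python
# Illustrative sketch only; the resolver attribute and method signatures are
# assumptions based on the calls made in this PR, not a documented API.
from collections import defaultdict

from kedro.framework.session import KedroSession
from kedro.io.core import is_parameter


def factory_datasets_by_type(dataset_names: list[str]) -> dict[str, list[str]]:
    """Group pattern-resolved (factory) datasets by their resolved dataset type."""
    with KedroSession.create() as session:
        context = session.load_context()
        resolver = context.catalog_config_resolver  # added by the DataCatalog2.0 work

        grouped: dict[str, list[str]] = defaultdict(list)
        resolved_configs = resolver.resolve_dataset_patterns(dataset_names)
        for ds_name, ds_config in zip(dataset_names, resolved_configs):
            if not is_parameter(ds_name) and resolver.match_pattern(ds_name):
                grouped[ds_config.get("type", "DefaultDataset")].append(ds_name)
        return dict(grouped)
```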

@@ -128,12 +151,11 @@ def _map_type_to_datasets(
datasets of the specific type as a value.
"""
mapping = defaultdict(list) # type: ignore[var-annotated]
for dataset in datasets:
is_param = dataset.startswith("params:") or dataset == "parameters"
if not is_param:
ds_type = datasets_meta[dataset].__class__.__name__
if dataset not in mapping[ds_type]:
mapping[ds_type].append(dataset)
for dataset_name in datasets:
if not is_parameter(dataset_name):
ds_type = datasets_meta[dataset_name].__class__.__name__
if dataset_name not in mapping[ds_type]:
mapping[ds_type].append(dataset_name)
return mapping


@@ -147,8 +169,13 @@ def _map_type_to_datasets(
required=True,
help="Name of a pipeline.",
)
@click.option(
"--new_catalog", "-n", "new_catalog", is_flag=True, help=NEW_CATALOG_ARG_HELP
)
@click.pass_obj
def create_catalog(metadata: ProjectMetadata, pipeline_name: str, env: str) -> None:
def create_catalog(
metadata: ProjectMetadata, pipeline_name: str, env: str, new_catalog: bool
) -> None:
"""Create Data Catalog YAML configuration with missing datasets.

Add ``MemoryDataset`` datasets to Data Catalog YAML configuration
@@ -161,6 +188,7 @@ def create_catalog(metadata: ProjectMetadata, pipeline_name: str, env: str) -> N
env = env or "base"
session = _create_session(metadata.package_name, env=env)
context = session.load_context()
catalog = context.catalog_new if new_catalog else context.catalog

pipeline = pipelines.get(pipeline_name)

@@ -170,20 +198,16 @@ def create_catalog(metadata: ProjectMetadata, pipeline_name: str, env: str) -> N
f"'{pipeline_name}' pipeline not found! Existing pipelines: {existing_pipelines}"
)

pipe_datasets = {
ds_name
for ds_name in pipeline.datasets()
if not ds_name.startswith("params:") and ds_name != "parameters"
pipeline_datasets = {
ds_name for ds_name in pipeline.datasets() if not is_parameter(ds_name)
}

catalog_datasets = {
ds_name
for ds_name in context.catalog._datasets.keys()
if not ds_name.startswith("params:") and ds_name != "parameters"
ds_name for ds_name in catalog.list() if not is_parameter(ds_name)
}

# Datasets that are missing in Data Catalog
missing_ds = sorted(pipe_datasets - catalog_datasets)
missing_ds = sorted(pipeline_datasets - catalog_datasets)
if missing_ds:
catalog_path = (
context.project_path
@@ -216,69 +240,94 @@ def _add_missing_datasets_to_catalog(missing_ds: list[str], catalog_path: Path)
@catalog.command("rank")
@env_option
@click.pass_obj
def rank_catalog_factories(metadata: ProjectMetadata, env: str) -> None:
@click.option(
"--new_catalog", "-n", "new_catalog", is_flag=True, help=NEW_CATALOG_ARG_HELP
)
def rank_catalog_factories(
metadata: ProjectMetadata, env: str, new_catalog: bool
) -> None:
"""List all dataset factories in the catalog, ranked by priority by which they are matched."""
session = _create_session(metadata.package_name, env=env)
context = session.load_context()

catalog_factories = {
**context.catalog._dataset_patterns,
**context.catalog._default_pattern,
}
if new_catalog:
catalog_factories = context.catalog_config_resolver.list_patterns()
else:
catalog_factories = list(
{
**context.catalog._dataset_patterns,
**context.catalog._default_pattern,
}.keys()
)
if catalog_factories:
click.echo(yaml.dump(list(catalog_factories.keys())))
click.echo(yaml.dump(catalog_factories))
else:
click.echo("There are no dataset factories in the catalog.")


@catalog.command("resolve")
@env_option
@click.pass_obj
def resolve_patterns(metadata: ProjectMetadata, env: str) -> None:
@click.option(
"--new_catalog", "-n", "new_catalog", is_flag=True, help=NEW_CATALOG_ARG_HELP
)
def resolve_patterns(metadata: ProjectMetadata, env: str, new_catalog: bool) -> None:
"""Resolve catalog factories against pipeline datasets. Note that this command is runner
agnostic and thus won't take into account any default dataset creation defined in the runner."""

session = _create_session(metadata.package_name, env=env)
context = session.load_context()

catalog_config = context.config_loader["catalog"]
credentials_config = context.config_loader.get("credentials", None)
data_catalog = DataCatalog.from_config(
catalog=catalog_config, credentials=credentials_config
)
catalog_config_resolver = None
if new_catalog:
data_catalog = context.catalog_new
catalog_config_resolver = context.catalog_config_resolver
explicit_datasets = {
ds_name: ds_config
for ds_name, ds_config in data_catalog.config.items()
if not is_parameter(ds_name)
}
else:
data_catalog = context.catalog
catalog_config = context.config_loader["catalog"]

explicit_datasets = {
ds_name: ds_config
for ds_name, ds_config in catalog_config.items()
if not data_catalog._is_pattern(ds_name)
}
explicit_datasets = {
ds_name: ds_config
for ds_name, ds_config in catalog_config.items()
if not data_catalog._is_pattern(ds_name)
}

target_pipelines = pipelines.keys()
datasets = set()
pipeline_datasets = set()

for pipe in target_pipelines:
pl_obj = pipelines.get(pipe)
if pl_obj:
datasets.update(pl_obj.datasets())
pipeline_datasets.update(pl_obj.datasets())

for ds_name in datasets:
is_param = ds_name.startswith("params:") or ds_name == "parameters"
if ds_name in explicit_datasets or is_param:
for ds_name in pipeline_datasets:
if ds_name in explicit_datasets or is_parameter(ds_name):
continue

matched_pattern = data_catalog._match_pattern(
data_catalog._dataset_patterns, ds_name
) or data_catalog._match_pattern(data_catalog._default_pattern, ds_name)
if matched_pattern:
ds_config_copy = copy.deepcopy(
data_catalog._dataset_patterns.get(matched_pattern)
or data_catalog._default_pattern.get(matched_pattern)
or {}
)
if new_catalog:
ds_config = catalog_config_resolver.resolve_dataset_patterns(ds_name)
else:
ds_config = None
matched_pattern = data_catalog._match_pattern(
data_catalog._dataset_patterns, ds_name
) or data_catalog._match_pattern(data_catalog._default_pattern, ds_name)
if matched_pattern:
ds_config_copy = copy.deepcopy(
data_catalog._dataset_patterns.get(matched_pattern)
or data_catalog._default_pattern.get(matched_pattern)
or {}
)
ds_config = data_catalog._resolve_config(
ds_name, matched_pattern, ds_config_copy
)

ds_config = data_catalog._resolve_config(
ds_name, matched_pattern, ds_config_copy
)
# Exclude MemoryDatasets not set in the catalog explicitly
if ds_config is not None:
explicit_datasets[ds_name] = ds_config

secho(yaml.dump(explicit_datasets))
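
Aside (not part of the diff): under the same assumptions about the resolver API, the `--new_catalog` path of `catalog resolve` boils down to roughly the following — gather the pipeline datasets, skip parameters and explicitly declared entries, and resolve the rest against the dataset factories.

```python
# Illustrative sketch only; mirrors the new-catalog branch of resolve_patterns above.
from kedro.framework.project import pipelines
from kedro.framework.session import KedroSession
from kedro.io.core import is_parameter

with KedroSession.create() as session:
    context = session.load_context()
    catalog = context.catalog_new                # assumed attribute, per this PR
    resolver = context.catalog_config_resolver   # assumed attribute, per this PR

    # Datasets declared explicitly in the catalog config
    resolved = {
        ds_name: ds_config
        for ds_name, ds_config in catalog.config.items()
        if not is_parameter(ds_name)
    }

    pipeline_datasets: set[str] = set()
    for pipeline_name in pipelines.keys():
        pipeline_obj = pipelines.get(pipeline_name)
        if pipeline_obj:
            pipeline_datasets.update(pipeline_obj.datasets())

    for ds_name in pipeline_datasets:
        if ds_name in resolved or is_parameter(ds_name):
            continue
        ds_config = resolver.resolve_dataset_patterns(ds_name)
        if ds_config is not None:  # exclude datasets that would only be runner-created MemoryDatasets
            resolved[ds_name] = ds_config
```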
5 changes: 5 additions & 0 deletions kedro/io/core.py
@@ -871,3 +871,8 @@ def validate_on_forbidden_chars(**kwargs: Any) -> None:
raise DatasetError(
f"Neither white-space nor semicolon are allowed in '{key}'."
)


def is_parameter(dataset_name: str) -> bool:
"""Check if dataset is a parameter."""
return dataset_name.startswith("params:") or dataset_name == "parameters"
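
A quick illustration of the new helper (no Kedro project needed; the dataset names below are made up):

```python
from kedro.io.core import is_parameter

# Parameter references use the "params:" prefix or the special name "parameters".
assert is_parameter("params:model_options")
assert is_parameter("parameters")
assert not is_parameter("companies")           # an ordinary dataset
assert not is_parameter("parameters_backup")   # only the exact name "parameters" counts
```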