diff --git a/kedro/framework/cli/catalog.py b/kedro/framework/cli/catalog.py
index 223980dade..973443856e 100644
--- a/kedro/framework/cli/catalog.py
+++ b/kedro/framework/cli/catalog.py
@@ -14,7 +14,7 @@
 from kedro.framework.cli.utils import KedroCliError, env_option, split_string
 from kedro.framework.project import pipelines, settings
 from kedro.framework.session import KedroSession
-from kedro.io.data_catalog import DataCatalog
+from kedro.io.core import is_parameter
 
 if TYPE_CHECKING:
     from pathlib import Path
@@ -22,6 +22,8 @@
     from kedro.framework.startup import ProjectMetadata
     from kedro.io import AbstractDataset
 
+NEW_CATALOG_ARG_HELP = """Use KedroDataCatalog instead of DataCatalog to run project."""
+
 
 def _create_session(package_name: str, **kwargs: Any) -> KedroSession:
     kwargs.setdefault("save_on_close", False)
@@ -49,8 +51,13 @@ def catalog() -> None:
     "the project pipeline is run by default.",
     callback=split_string,
 )
+@click.option(
+    "--new_catalog", "-n", "new_catalog", is_flag=True, help=NEW_CATALOG_ARG_HELP
+)
 @click.pass_obj
-def list_datasets(metadata: ProjectMetadata, pipeline: str, env: str) -> None:
+def list_datasets(  # noqa: PLR0912
+    metadata: ProjectMetadata, pipeline: str, env: str, new_catalog: bool
+) -> None:
     """Show datasets per type."""
     title = "Datasets in '{}' pipeline"
     not_mentioned = "Datasets not mentioned in pipeline"
@@ -61,8 +68,14 @@ def list_datasets(metadata: ProjectMetadata, pipeline: str, env: str) -> None:
     context = session.load_context()
 
     try:
-        data_catalog = context.catalog
-        datasets_meta = data_catalog._datasets
+        catalog_config_resolver = None
+        if new_catalog:
+            data_catalog = context.catalog_new
+            datasets_meta = data_catalog.datasets
+            catalog_config_resolver = context.catalog_config_resolver
+        else:
+            data_catalog = context.catalog
+            datasets_meta = data_catalog._datasets
         catalog_ds = set(data_catalog.list())
     except Exception as exc:
         raise KedroCliError(
@@ -86,23 +99,33 @@ def list_datasets(metadata: ProjectMetadata, pipeline: str, env: str) -> None:
         default_ds = pipeline_ds - catalog_ds
         used_ds = catalog_ds - unused_ds
 
-        # resolve any factory datasets in the pipeline
         factory_ds_by_type = defaultdict(list)
-        for ds_name in default_ds:
-            matched_pattern = data_catalog._match_pattern(
-                data_catalog._dataset_patterns, ds_name
-            ) or data_catalog._match_pattern(data_catalog._default_pattern, ds_name)
-            if matched_pattern:
-                ds_config_copy = copy.deepcopy(
-                    data_catalog._dataset_patterns.get(matched_pattern)
-                    or data_catalog._default_pattern.get(matched_pattern)
-                    or {}
-                )
-
-                ds_config = data_catalog._resolve_config(
-                    ds_name, matched_pattern, ds_config_copy
-                )
-                factory_ds_by_type[ds_config["type"]].append(ds_name)
+        if new_catalog:
+            resolved_configs = catalog_config_resolver.resolve_dataset_patterns(
+                default_ds
+            )
+            for ds_name, ds_config in zip(default_ds, resolved_configs):
+                if catalog_config_resolver.match_pattern(ds_name):
+                    factory_ds_by_type[ds_config.get("type", "DefaultDataset")].append(
+                        ds_name
+                    )
+        else:
+            # resolve any factory datasets in the pipeline
+            for ds_name in default_ds:
+                matched_pattern = data_catalog._match_pattern(
+                    data_catalog._dataset_patterns, ds_name
+                ) or data_catalog._match_pattern(data_catalog._default_pattern, ds_name)
+                if matched_pattern:
+                    ds_config_copy = copy.deepcopy(
+                        data_catalog._dataset_patterns.get(matched_pattern)
+                        or data_catalog._default_pattern.get(matched_pattern)
+                        or {}
+                    )
+
+                    ds_config = data_catalog._resolve_config(
+                        ds_name, matched_pattern, ds_config_copy
+                    )
+                    factory_ds_by_type[ds_config["type"]].append(ds_name)
 
         default_ds = default_ds - set(chain.from_iterable(factory_ds_by_type.values()))
 
@@ -128,12 +151,11 @@ def _map_type_to_datasets(
     datasets of the specific type as a value.
     """
     mapping = defaultdict(list)  # type: ignore[var-annotated]
-    for dataset in datasets:
-        is_param = dataset.startswith("params:") or dataset == "parameters"
-        if not is_param:
-            ds_type = datasets_meta[dataset].__class__.__name__
-            if dataset not in mapping[ds_type]:
-                mapping[ds_type].append(dataset)
+    for dataset_name in datasets:
+        if not is_parameter(dataset_name):
+            ds_type = datasets_meta[dataset_name].__class__.__name__
+            if dataset_name not in mapping[ds_type]:
+                mapping[ds_type].append(dataset_name)
 
     return mapping
 
@@ -147,8 +169,13 @@ def _map_type_to_datasets(
     required=True,
     help="Name of a pipeline.",
 )
+@click.option(
+    "--new_catalog", "-n", "new_catalog", is_flag=True, help=NEW_CATALOG_ARG_HELP
+)
 @click.pass_obj
-def create_catalog(metadata: ProjectMetadata, pipeline_name: str, env: str) -> None:
+def create_catalog(
+    metadata: ProjectMetadata, pipeline_name: str, env: str, new_catalog: bool
+) -> None:
     """Create Data Catalog YAML configuration with missing datasets.
 
     Add ``MemoryDataset`` datasets to Data Catalog YAML configuration
@@ -161,6 +188,7 @@ def create_catalog(metadata: ProjectMetadata, pipeline_name: str, env: str) -> N
     env = env or "base"
     session = _create_session(metadata.package_name, env=env)
    context = session.load_context()
+    catalog = context.catalog_new if new_catalog else context.catalog
 
     pipeline = pipelines.get(pipeline_name)
 
@@ -170,20 +198,16 @@ def create_catalog(metadata: ProjectMetadata, pipeline_name: str, env: str) -> N
             f"'{pipeline_name}' pipeline not found! Existing pipelines: {existing_pipelines}"
         )
 
-    pipe_datasets = {
-        ds_name
-        for ds_name in pipeline.datasets()
-        if not ds_name.startswith("params:") and ds_name != "parameters"
+    pipeline_datasets = {
+        ds_name for ds_name in pipeline.datasets() if not is_parameter(ds_name)
     }
 
     catalog_datasets = {
-        ds_name
-        for ds_name in context.catalog._datasets.keys()
-        if not ds_name.startswith("params:") and ds_name != "parameters"
+        ds_name for ds_name in catalog.list() if not is_parameter(ds_name)
     }
 
     # Datasets that are missing in Data Catalog
-    missing_ds = sorted(pipe_datasets - catalog_datasets)
+    missing_ds = sorted(pipeline_datasets - catalog_datasets)
     if missing_ds:
         catalog_path = (
             context.project_path
@@ -216,17 +240,27 @@ def _add_missing_datasets_to_catalog(missing_ds: list[str], catalog_path: Path)
 @catalog.command("rank")
 @env_option
 @click.pass_obj
-def rank_catalog_factories(metadata: ProjectMetadata, env: str) -> None:
+@click.option(
+    "--new_catalog", "-n", "new_catalog", is_flag=True, help=NEW_CATALOG_ARG_HELP
+)
+def rank_catalog_factories(
+    metadata: ProjectMetadata, env: str, new_catalog: bool
+) -> None:
     """List all dataset factories in the catalog, ranked by priority by which they are matched."""
     session = _create_session(metadata.package_name, env=env)
     context = session.load_context()
 
-    catalog_factories = {
-        **context.catalog._dataset_patterns,
-        **context.catalog._default_pattern,
-    }
+    if new_catalog:
+        catalog_factories = context.catalog_config_resolver.list_patterns()
+    else:
+        catalog_factories = list(
+            {
+                **context.catalog._dataset_patterns,
+                **context.catalog._default_pattern,
+            }.keys()
+        )
     if catalog_factories:
-        click.echo(yaml.dump(list(catalog_factories.keys())))
+        click.echo(yaml.dump(catalog_factories))
     else:
         click.echo("There are no dataset factories in the catalog.")
 
@@ -234,51 +268,66 @@ def rank_catalog_factories(metadata: ProjectMetadata, env: str) -> None:
 @catalog.command("resolve")
 @env_option
 @click.pass_obj
-def resolve_patterns(metadata: ProjectMetadata, env: str) -> None:
+@click.option(
+    "--new_catalog", "-n", "new_catalog", is_flag=True, help=NEW_CATALOG_ARG_HELP
+)
+def resolve_patterns(metadata: ProjectMetadata, env: str, new_catalog: bool) -> None:
     """Resolve catalog factories against pipeline datasets. Note that this command is runner
     agnostic and thus won't take into account any default dataset creation defined in the runner."""
 
     session = _create_session(metadata.package_name, env=env)
     context = session.load_context()
 
-    catalog_config = context.config_loader["catalog"]
-    credentials_config = context.config_loader.get("credentials", None)
-    data_catalog = DataCatalog.from_config(
-        catalog=catalog_config, credentials=credentials_config
-    )
+    catalog_config_resolver = None
+    if new_catalog:
+        data_catalog = context.catalog_new
+        catalog_config_resolver = context.catalog_config_resolver
+        explicit_datasets = {
+            ds_name: ds_config
+            for ds_name, ds_config in data_catalog.config.items()
+            if not is_parameter(ds_name)
+        }
+    else:
+        data_catalog = context.catalog
+        catalog_config = context.config_loader["catalog"]
 
-    explicit_datasets = {
-        ds_name: ds_config
-        for ds_name, ds_config in catalog_config.items()
-        if not data_catalog._is_pattern(ds_name)
-    }
+        explicit_datasets = {
+            ds_name: ds_config
+            for ds_name, ds_config in catalog_config.items()
+            if not data_catalog._is_pattern(ds_name)
+        }
 
     target_pipelines = pipelines.keys()
 
-    datasets = set()
+    pipeline_datasets = set()
     for pipe in target_pipelines:
         pl_obj = pipelines.get(pipe)
         if pl_obj:
-            datasets.update(pl_obj.datasets())
+            pipeline_datasets.update(pl_obj.datasets())
 
-    for ds_name in datasets:
-        is_param = ds_name.startswith("params:") or ds_name == "parameters"
-        if ds_name in explicit_datasets or is_param:
+    for ds_name in pipeline_datasets:
+        if ds_name in explicit_datasets or is_parameter(ds_name):
             continue
 
-        matched_pattern = data_catalog._match_pattern(
-            data_catalog._dataset_patterns, ds_name
-        ) or data_catalog._match_pattern(data_catalog._default_pattern, ds_name)
-        if matched_pattern:
-            ds_config_copy = copy.deepcopy(
-                data_catalog._dataset_patterns.get(matched_pattern)
-                or data_catalog._default_pattern.get(matched_pattern)
-                or {}
-            )
+        if new_catalog:
+            ds_config = catalog_config_resolver.resolve_dataset_patterns(ds_name)
+        else:
+            ds_config = None
+            matched_pattern = data_catalog._match_pattern(
+                data_catalog._dataset_patterns, ds_name
+            ) or data_catalog._match_pattern(data_catalog._default_pattern, ds_name)
+            if matched_pattern:
+                ds_config_copy = copy.deepcopy(
+                    data_catalog._dataset_patterns.get(matched_pattern)
+                    or data_catalog._default_pattern.get(matched_pattern)
+                    or {}
+                )
+                ds_config = data_catalog._resolve_config(
+                    ds_name, matched_pattern, ds_config_copy
+                )
 
-            ds_config = data_catalog._resolve_config(
-                ds_name, matched_pattern, ds_config_copy
-            )
+        # Exclude MemoryDatasets not set in the catalog explicitly
+        if ds_config is not None:
             explicit_datasets[ds_name] = ds_config
 
     secho(yaml.dump(explicit_datasets))
diff --git a/kedro/io/core.py b/kedro/io/core.py
index f3975c9c3c..1f9dd8ff48 100644
--- a/kedro/io/core.py
+++ b/kedro/io/core.py
@@ -871,3 +871,8 @@ def validate_on_forbidden_chars(**kwargs: Any) -> None:
             raise DatasetError(
                 f"Neither white-space nor semicolon are allowed in '{key}'."
             )
+
+
+def is_parameter(dataset_name: str) -> bool:
+    """Check if dataset is a parameter."""
+    return dataset_name.startswith("params:") or dataset_name == "parameters"
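
Reviewer note (not part of the patch): a minimal sketch of how the new is_parameter helper behaves once this change is applied. The dataset names below are illustrative only, not taken from a real project.

    # Assumes the patch above is installed, so the helper is importable.
    from kedro.io.core import is_parameter

    names = ["companies", "params:model_options", "parameters", "reviews"]

    # Same filtering now used in create_catalog and _map_type_to_datasets:
    # parameter entries are skipped, real datasets are kept.
    kept = [name for name in names if not is_parameter(name)]
    print(kept)  # ['companies', 'reviews']

Each catalog command also gains the opt-in flag described by NEW_CATALOG_ARG_HELP, e.g. "kedro catalog list --new_catalog" (short form -n), which switches the command from context.catalog to context.catalog_new and its KedroDataCatalog config resolver.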