From 075defb5e9d715365d467b1f97e01c57716dd1b0 Mon Sep 17 00:00:00 2001 From: Sultan Iman Date: Mon, 25 Mar 2024 16:49:59 +0100 Subject: [PATCH] Add simple validator for pipeline arguments --- dlt/common/cli/runner/pipeline_script.py | 53 ++++++++++++++++++++++++ dlt/common/cli/runner/runner.py | 1 + 2 files changed, 54 insertions(+) diff --git a/dlt/common/cli/runner/pipeline_script.py b/dlt/common/cli/runner/pipeline_script.py index a396d0d600..2009a9b603 100644 --- a/dlt/common/cli/runner/pipeline_script.py +++ b/dlt/common/cli/runner/pipeline_script.py @@ -5,6 +5,7 @@ from contextlib import contextmanager from types import ModuleType +from typing_extensions import get_args import dlt @@ -12,7 +13,9 @@ from dlt.cli.utils import parse_init_script from dlt.common.cli.runner.errors import RunnerError from dlt.common.cli.runner.types import PipelineMembers, RunnerParams +from dlt.common.destination.capabilities import TLoaderFileFormat from dlt.common.pipeline import LoadInfo +from dlt.common.schema.typing import TSchemaEvolutionMode, TWriteDisposition from dlt.sources import DltResource, DltSource @@ -43,6 +46,9 @@ def __init__(self, params: RunnerParams) -> None: """Directory in which pipeline script lives""" self.has_pipeline_auto_runs: bool = False + self.supported_formats = get_args(TLoaderFileFormat) + self.supported_evolution_modes = get_args(TSchemaEvolutionMode) + self.supported_write_disposition = get_args(TWriteDisposition) # Now we need to patch and store pipeline code visitor = parse_init_script( @@ -143,3 +149,50 @@ def run_arguments(self) -> t.Dict[str, str]: run_options[arg_name] = value return run_options + + def validate_arguments(self) -> None: + """Validates and checks""" + supported_args = { + "destination", + "staging", + "credentials", + "table_name", + "write_disposition", + "dataset_name", + "primary_key", + "schema_contract", + "loader_file_format", + } + errors = [] + arguments = self.run_arguments + for arg_name, _ in arguments.items(): + if arg_name not in supported_args: + errors.append(f"Invalid argument {arg_name}") + + if ( + "write_disposition" in arguments + and arguments["write_disposition"] not in self.supported_write_disposition + ): + errors.append( + "Invalid write disposition, select one of" + f" {'|'.join(self.supported_write_disposition)}" + ) + if ( + "loader_file_format" in arguments + and arguments["loader_file_format"] not in self.supported_formats + ): + errors.append( + f"Invalid loader file format, select one of {'|'.join(self.supported_formats)}" + ) + if ( + "schema_contract" in arguments + and arguments["schema_contract"] not in self.supported_evolution_modes + ): + errors.append( + "Invalid schema_contract mode, select one of" + f" {'|'.join(self.supported_evolution_modes)}" + ) + + if errors: + all_errors = "\n".join(errors) + raise RunnerError(f"One or more arguments are invalid:\n{all_errors}") diff --git a/dlt/common/cli/runner/runner.py b/dlt/common/cli/runner/runner.py index 476ec0393c..dc99c41e58 100644 --- a/dlt/common/cli/runner/runner.py +++ b/dlt/common/cli/runner/runner.py @@ -13,6 +13,7 @@ class PipelineRunner: def __init__(self, params: RunnerParams) -> None: self.params = params self.script = PipelineScript(params) + self.script.validate_arguments() self.params.pipeline_workdir = self.script.workdir self.inquirer = Inquirer(self.params, self.script.pipeline_members)