diff --git a/dlt/cli/_dlt.py b/dlt/cli/_dlt.py
index 2332c0286c..c0d9918ef3 100644
--- a/dlt/cli/_dlt.py
+++ b/dlt/cli/_dlt.py
@@ -21,6 +21,8 @@
     DEFAULT_VERIFIED_SOURCES_REPO,
 )
 from dlt.cli.pipeline_command import pipeline_command, DLT_PIPELINE_COMMAND_DOCS_URL
+from dlt.cli.run_command import run_pipeline_command
+from dlt.common.cli.runner.help import run_args_help
 from dlt.cli.telemetry_command import (
     DLT_TELEMETRY_DOCS_URL,
     change_telemetry_status_command,
@@ -135,7 +137,11 @@ def deploy_command_wrapper(
 
 @utils.track_command("pipeline", True, "operation")
 def pipeline_command_wrapper(
-    operation: str, pipeline_name: str, pipelines_dir: str, verbosity: int, **command_kwargs: Any
+    operation: str,
+    pipeline_name: str,
+    pipelines_dir: str,
+    verbosity: int,
+    **command_kwargs: Any,
 ) -> int:
     try:
         pipeline_command(operation, pipeline_name, pipelines_dir, verbosity, **command_kwargs)
@@ -205,7 +211,11 @@ def __init__(
         help: str = None,  # noqa
     ) -> None:
         super(TelemetryAction, self).__init__(
-            option_strings=option_strings, dest=dest, default=default, nargs=0, help=help
+            option_strings=option_strings,
+            dest=dest,
+            default=default,
+            nargs=0,
+            help=help,
         )
 
     def __call__(
@@ -230,7 +240,11 @@ def __init__(
         help: str = None,  # noqa
     ) -> None:
         super(NonInteractiveAction, self).__init__(
-            option_strings=option_strings, dest=dest, default=default, nargs=0, help=help
+            option_strings=option_strings,
+            dest=dest,
+            default=default,
+            nargs=0,
+            help=help,
         )
 
     def __call__(
@@ -252,7 +266,11 @@ def __init__(
         help: str = None,  # noqa
     ) -> None:
         super(DebugAction, self).__init__(
-            option_strings=option_strings, dest=dest, default=default, nargs=0, help=help
+            option_strings=option_strings,
+            dest=dest,
+            default=default,
+            nargs=0,
+            help=help,
         )
 
     def __call__(
@@ -273,7 +291,9 @@ def main() -> int:
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
     parser.add_argument(
-        "--version", action="version", version="%(prog)s {version}".format(version=__version__)
+        "--version",
+        action="version",
+        version="%(prog)s {version}".format(version=__version__),
     )
     parser.add_argument(
         "--disable-telemetry",
@@ -366,7 +386,9 @@ def main() -> int:
         "deploy", help="Creates a deployment package for a selected pipeline script"
     )
     deploy_cmd.add_argument(
-        "pipeline_script_path", metavar="pipeline-script-path", help="Path to a pipeline script"
+        "pipeline_script_path",
+        metavar="pipeline-script-path",
+        help="Path to a pipeline script",
     )
     deploy_sub_parsers = deploy_cmd.add_subparsers(dest="deployment_method")
 
@@ -423,25 +445,37 @@ def main() -> int:
     )
     deploy_cmd.add_argument("--help", "-h", nargs="?", const=True)
     deploy_cmd.add_argument(
-        "pipeline_script_path", metavar="pipeline-script-path", nargs=argparse.REMAINDER
+        "pipeline_script_path",
+        metavar="pipeline-script-path",
+        nargs=argparse.REMAINDER,
     )
 
     schema = subparsers.add_parser("schema", help="Shows, converts and upgrades schemas")
     schema.add_argument(
-        "file", help="Schema file name, in yaml or json format, will autodetect based on extension"
+        "file",
+        help="Schema file name, in yaml or json format, will autodetect based on extension",
     )
     schema.add_argument(
-        "--format", choices=["json", "yaml"], default="yaml", help="Display schema in this format"
+        "--format",
+        choices=["json", "yaml"],
+        default="yaml",
+        help="Display schema in this format",
     )
     schema.add_argument(
-        "--remove-defaults", action="store_true", help="Does not show default hint values"
+        "--remove-defaults",
+        action="store_true",
+        help="Does not show default hint values",
     )
 
     pipe_cmd = subparsers.add_parser(
         "pipeline", help="Operations on pipelines that were ran locally"
     )
     pipe_cmd.add_argument(
-        "--list-pipelines", "-l", default=False, action="store_true", help="List local pipelines"
+        "--list-pipelines",
+        "-l",
+        default=False,
+        action="store_true",
+        help="List local pipelines",
     )
     pipe_cmd.add_argument(
         "--hot-reload",
@@ -464,10 +498,12 @@ def main() -> int:
 
     pipe_cmd_sync_parent = argparse.ArgumentParser(add_help=False)
     pipe_cmd_sync_parent.add_argument(
-        "--destination", help="Sync from this destination when local pipeline state is missing."
+        "--destination",
+        help="Sync from this destination when local pipeline state is missing.",
     )
     pipe_cmd_sync_parent.add_argument(
-        "--dataset-name", help="Dataset name to sync from when local pipeline state is missing."
+        "--dataset-name",
+        help="Dataset name to sync from when local pipeline state is missing.",
     )
 
     pipeline_subparsers.add_parser(
@@ -510,7 +546,9 @@ def main() -> int:
         help="Display schema in this format",
     )
     pipe_cmd_schema.add_argument(
-        "--remove-defaults", action="store_true", help="Does not show default hint values"
+        "--remove-defaults",
+        action="store_true",
+        help="Does not show default hint values",
     )
 
     pipe_cmd_drop = pipeline_subparsers.add_parser(
@@ -552,7 +590,8 @@ def main() -> int:
     )
 
     pipe_cmd_package = pipeline_subparsers.add_parser(
-        "load-package", help="Displays information on load package, use -v or -vv for more info"
+        "load-package",
+        help="Displays information on load package, use -v or -vv for more info",
     )
     pipe_cmd_package.add_argument(
         "load_id",
@@ -563,6 +602,48 @@ def main() -> int:
 
     subparsers.add_parser("telemetry", help="Shows telemetry status")
 
+    # CLI pipeline runner
+    run_cmd = subparsers.add_parser(
+        "run",
+        help="Run pipelines from a given module or file",
+        formatter_class=argparse.RawTextHelpFormatter,
+    )
+
+    run_cmd.add_argument(
+        "module",
+        help="Path to a module or Python file with pipelines",
+    )
+
+    run_cmd.add_argument(
+        "pipeline_name",
+        type=str,
+        nargs="?",
+        help="Pipeline name",
+    )
+
+    run_cmd.add_argument(
+        "source",
+        type=str,
+        nargs="?",
+        help="Source or resource name",
+    )
+
+    # TODO: enable once pipeline.run with full refresh option is available
+    # run_cmd.add_argument(
+    #     "--full-refresh",
+    #     action="store_true",
+    #     default=False,
+    #     help="When used, the pipeline will run in full-refresh mode",
+    # )
+
+    run_cmd.add_argument(
+        "--args",
+        "-a",
+        nargs="+",
+        default=[],
+        help=run_args_help,
+    )
+
     args = parser.parse_args()
 
     if Venv.is_virtual_env() and not Venv.is_venv_activated():
@@ -575,6 +656,13 @@ def main() -> int:
 
     if args.command == "schema":
         return schema_command_wrapper(args.file, args.format, args.remove_defaults)
+    elif args.command == "run":
+        return run_pipeline_command(
+            args.module,
+            args.pipeline_name,
+            args.source,
+            args.args,
+        )
     elif args.command == "pipeline":
         if args.list_pipelines:
             return pipeline_command_wrapper("list", "-", args.pipelines_dir, args.verbosity)
@@ -596,7 +684,11 @@ def main() -> int:
             return -1
         else:
            return init_command_wrapper(
+                args.source,
+                args.destination,
+                args.generic,
+                args.location,
+                args.branch,
            )
    elif args.command == "deploy":
        try:
click.style(msg, fg="blue", reset=True) + + def warning_style(msg: str) -> str: return click.style(msg, fg="yellow", reset=True) +def error_style(msg: str) -> str: + return click.style(msg, fg="red", reset=True) + + def error(msg: str) -> None: click.secho("ERROR: " + msg, fg="red") +def info(msg: str) -> None: + click.secho("INFO: " + msg, fg="blue") + + def warning(msg: str) -> None: click.secho("WARNING: " + msg, fg="yellow") diff --git a/dlt/cli/run_command.py b/dlt/cli/run_command.py new file mode 100644 index 0000000000..0575f4e39e --- /dev/null +++ b/dlt/cli/run_command.py @@ -0,0 +1,51 @@ +import os +from typing import List, Optional + +from dlt.cli import echo as fmt +from dlt.cli.utils import track_command +from dlt.common.cli.runner.errors import FriendlyExit, PreflightError, RunnerError +from dlt.common.cli.runner.runner import PipelineRunner +from dlt.common.cli.runner.types import RunnerParams +from dlt.pipeline.exceptions import PipelineStepFailed + + +@track_command("run", False) +def run_pipeline_command( + module: str, + pipeline_name: Optional[str] = None, + source_name: Optional[str] = None, + args: Optional[List[str]] = None, +) -> int: + """Run the given module if any pipeline and sources or resources exist in it + + Args: + module (str): path to python module or file with dlt artifacts + pipeline_name (Optional[str]): Pipeline name + source_name (Optional[str]): Source or resource name + args (Optiona[List[str]]): List of arguments to `pipeline.run` method + + Returns: + (int): exit code + """ + params = RunnerParams( + module, + current_dir=os.getcwd(), + pipeline_name=pipeline_name, + source_name=source_name, + args=args, + ) + + try: + with PipelineRunner(params=params) as runner: + load_info = runner.run() + fmt.echo("") + fmt.echo(load_info) + except PipelineStepFailed: + raise + except RunnerError as ex: + fmt.echo(fmt.error_style(ex.message)) + return -1 + except (FriendlyExit, PreflightError): + fmt.info("Stopping...") + + return 0 diff --git a/dlt/common/cli/__init__.py b/dlt/common/cli/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dlt/common/cli/runner/__init__.py b/dlt/common/cli/runner/__init__.py new file mode 100644 index 0000000000..eafb741856 --- /dev/null +++ b/dlt/common/cli/runner/__init__.py @@ -0,0 +1,10 @@ +from dlt.common.cli.runner.pipeline_script import PipelineScript +from dlt.common.cli.runner.runner import PipelineRunner +from dlt.common.cli.runner.types import PipelineMembers, RunnerParams + +__all__ = ( + "PipelineRunner", + "RunnerParams", + "PipelineMembers", + "PipelineScript", +) diff --git a/dlt/common/cli/runner/errors.py b/dlt/common/cli/runner/errors.py new file mode 100644 index 0000000000..1816a3c7aa --- /dev/null +++ b/dlt/common/cli/runner/errors.py @@ -0,0 +1,12 @@ +class RunnerError(Exception): + def __init__(self, message: str) -> None: + self.message = message + super().__init__(message) + + +class FriendlyExit(Exception): + pass + + +class PreflightError(Exception): + pass diff --git a/dlt/common/cli/runner/help.py b/dlt/common/cli/runner/help.py new file mode 100644 index 0000000000..1847c8d5a0 --- /dev/null +++ b/dlt/common/cli/runner/help.py @@ -0,0 +1,24 @@ +from typing_extensions import get_args +from dlt.cli import echo as fmt +from dlt.common.destination.capabilities import TLoaderFileFormat +from dlt.common.schema.typing import TSchemaEvolutionMode, TWriteDisposition + + +supported_formats = "|".join(get_args(TLoaderFileFormat)) +supported_evolution_modes = 
"|".join(get_args(TSchemaEvolutionMode)) +supported_write_disposition = "|".join(get_args(TWriteDisposition)) + +run_args_help = "".join( + ( + "Supported arguments passed to pipeline.run are", + fmt.info_style("\n - destination=string"), + fmt.info_style("\n - staging=string,"), + fmt.info_style("\n - credentials=string,"), + fmt.info_style("\n - table_name=string,"), + fmt.info_style(f"\n - write_disposition={supported_write_disposition},"), + fmt.info_style("\n - dataset_name=string,"), + fmt.info_style("\n - primary_key=string,"), + fmt.info_style(f"\n - schema_contract={supported_evolution_modes},"), + fmt.info_style(f"\n - loader_file_format={supported_formats}"), + ) +) diff --git a/dlt/common/cli/runner/inquirer.py b/dlt/common/cli/runner/inquirer.py new file mode 100644 index 0000000000..1143f1b677 --- /dev/null +++ b/dlt/common/cli/runner/inquirer.py @@ -0,0 +1,175 @@ +import os +import typing as t +import textwrap as tw +import dlt + +from dlt.cli import echo as fmt +from dlt.common.cli.runner.errors import FriendlyExit, PreflightError, RunnerError +from dlt.common.cli.runner.types import PipelineMembers, RunnerParams +from dlt.sources import DltResource, DltSource + + +dot_dlt = ".dlt" +select_message = """Please select your %s: +%s +""" + + +def make_select_options( + select_what: str, items: t.Dict[str, t.Any] +) -> t.Tuple[str, t.List[str], t.List[str]]: + options = [] + option_values = [] + choice_message = "" + for idx, name in enumerate(items.keys()): + choice_message += f"{idx} - {name}\n" + options.append(str(idx)) + option_values.append(name) + + choice_message += "n - exit\n" + options.append("n") + return select_message % (select_what, choice_message), options, option_values + + +class Inquirer: + """This class does pre flight checks required to run the pipeline. + Also handles user input to allow users to select pipeline, resources and sources. 
+ """ + + def __init__(self, params: RunnerParams, pipeline_members: PipelineMembers) -> None: + self.params = params + self.pipelines = pipeline_members.get("pipelines") + self.sources = pipeline_members.get("sources") + + def maybe_ask(self) -> t.Tuple[dlt.Pipeline, t.Union[DltResource, DltSource]]: + """Shows prompts to select pipeline, resources and sources + + Returns: + (DltResource, DltSource): a pair of pipeline and resources and sources + """ + self.preflight_checks() + pipeline_name = self.get_pipeline_name() + pipeline = self.pipelines[pipeline_name] + + source_name = self.get_source_name() + resource = self.sources[source_name] + if isinstance(resource, DltResource): + label = "Resource" + else: + label = "Source" + + real_source_name = "" + if resource.name != source_name: + real_source_name = f" ({resource.name})" + + real_pipeline_name = "" + if pipeline.pipeline_name != pipeline_name: + real_pipeline_name = f" ({pipeline.pipeline_name})" + + fmt.echo("\nPipeline: " + fmt.style(pipeline_name + real_pipeline_name, fg="blue")) + + fmt.echo(f"{label}: " + fmt.style(source_name + real_source_name, fg="blue")) + return pipeline, resource + + def get_pipeline_name(self) -> str: + if self.params.pipeline_name: + return self.params.pipeline_name + + if len(self.pipelines) > 1: + message, options, values = make_select_options("pipeline", self.pipelines) + choice = self.ask(message, options, default="n") + pipeline_name = values[choice] + return pipeline_name + + pipeline_name, _ = next(iter(self.pipelines.items())) + return pipeline_name + + def get_source_name(self) -> str: + if self.params.source_name: + return self.params.source_name + + if len(self.sources) > 1: + message, options, values = make_select_options("resource or source", self.sources) + choice = self.ask(message, options, default="n") + source_name = values[choice] + return source_name + + source_name, _ = next(iter(self.sources.items())) + return source_name + + def ask(self, message: str, options: t.List[str], default: t.Optional[str] = None) -> int: + choice = fmt.prompt(message, options, default=default) + if choice == "n": + raise FriendlyExit() + + return int(choice) + + def preflight_checks(self) -> None: + if pipeline_name := self.params.pipeline_name: + if pipeline_name not in self.pipelines: + fmt.echo( + fmt.error_style( + f"Pipeline {pipeline_name} has not been found in pipeline script" + "You can choose one of: " + + ", ".join(self.pipelines.keys()) + ) + ) + raise PreflightError() + + if source_name := self.params.source_name: + if source_name not in self.sources: + fmt.echo( + fmt.error_style( + f"Source or resouce with name: {source_name} has not been found in pipeline" + " script. 
\n You can choose one of: " + + ", ".join(self.sources.keys()) + ) + ) + raise PreflightError() + + if self.params.current_dir != self.params.pipeline_workdir: + fmt.echo(f"\nCurrent workdir: {fmt.style(self.params.current_dir, fg='blue')}") + fmt.echo(f"Pipeline workdir: {fmt.style(self.params.pipeline_workdir, fg='blue')}\n") + + fmt.echo( + fmt.warning_style( + "Current working directory is different from the " + f"pipeline script {self.params.pipeline_workdir}" + ) + ) + + has_cwd_config = self.has_dlt_config(self.params.current_dir) + has_pipeline_config = self.has_dlt_config(self.params.pipeline_workdir) + if has_cwd_config and has_pipeline_config: + message = tw.dedent( + f"""Using {dot_dlt} in current directory, if you intended to use """ + f"""{self.params.pipeline_workdir}/{dot_dlt}, please change your current directory.""", + ) + fmt.echo(fmt.warning_style(message)) + elif not has_cwd_config and has_pipeline_config: + fmt.echo( + fmt.error_style( + f"{dot_dlt} is missing in current directory but exists in pipeline script's" + " directory" + ) + ) + fmt.echo( + fmt.info_style( + "If needed please change your current directory to" + f" {self.params.pipeline_workdir} and try again" + ) + ) + + def check_if_runnable(self) -> None: + if not self.pipelines: + raise RunnerError( + f"No pipeline instances found in pipeline script {self.params.script_path}" + ) + + if not self.sources: + raise RunnerError( + f"No source or resources found in pipeline script {self.params.script_path}" + ) + + def has_dlt_config(self, path: str) -> bool: + return os.path.exists(os.path.join(path, ".dlt")) diff --git a/dlt/common/cli/runner/pipeline_script.py b/dlt/common/cli/runner/pipeline_script.py new file mode 100644 index 0000000000..2009a9b603 --- /dev/null +++ b/dlt/common/cli/runner/pipeline_script.py @@ -0,0 +1,198 @@ +import os +import importlib +import typing as t +import sys + +from contextlib import contextmanager +from types import ModuleType +from typing_extensions import get_args + +import dlt + +from dlt.cli import echo as fmt +from dlt.cli.utils import parse_init_script +from dlt.common.cli.runner.errors import RunnerError +from dlt.common.cli.runner.types import PipelineMembers, RunnerParams +from dlt.common.destination.capabilities import TLoaderFileFormat +from dlt.common.pipeline import LoadInfo +from dlt.common.schema.typing import TSchemaEvolutionMode, TWriteDisposition +from dlt.sources import DltResource, DltSource + + +COMMAND_NAME: t.Final[str] = "run" + + +class PipelineScript: + """Handles pipeline source code and prepares it to run + + Things it does + + 1. Reads the source code, + 2. Stubs all pipeline.run calls, + 3. Creates a temporary module in the location of pipeline file, + 4. Imports rewritten module, + 5. Prepares module name, + 6. Exposes ready to use module instance. 
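
A sketch of the selection flow: `make_select_options` turns the discovered members into a prompt message plus the option lists that `Inquirer.ask` feeds to `fmt.prompt`. The member dict here is illustrative:

```python
from dlt.common.cli.runner.inquirer import make_select_options

# Two hypothetical pipelines discovered in a script:
pipelines = {"quads_pipeline": object(), "numbers_pipeline": object()}
message, options, values = make_select_options("pipeline", pipelines)

assert options == ["0", "1", "n"]  # valid prompt answers; "n" raises FriendlyExit
assert values == ["quads_pipeline", "numbers_pipeline"]
print(message)
# Please select your pipeline:
# 0 - quads_pipeline
# 1 - numbers_pipeline
# n - exit
```
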
+ """ + + def __init__(self, params: RunnerParams) -> None: + self.params = params + """This means that user didn't specify pipeline name or resource and source name + + And we need to check if there is 1 pipeline and 1 resource or source to run right + away if there are multiple resources and sources then we need to provide a CLI prompt + """ + self.workdir = os.path.dirname(os.path.abspath(params.script_path)) + """Directory in which pipeline script lives""" + + self.has_pipeline_auto_runs: bool = False + self.supported_formats = get_args(TLoaderFileFormat) + self.supported_evolution_modes = get_args(TSchemaEvolutionMode) + self.supported_write_disposition = get_args(TWriteDisposition) + + # Now we need to patch and store pipeline code + visitor = parse_init_script( + COMMAND_NAME, + self.script_contents, + self.module_name, + ) + self.source_code = visitor.source + + def load_module(self, script_path: str) -> ModuleType: + """Loads pipeline module from a given location""" + with self.expect_no_pipeline_runs(): + spec = importlib.util.spec_from_file_location(self.module_name, script_path) + module = importlib.util.module_from_spec(spec) + sys.modules[self.module_name] = module + spec.loader.exec_module(module) + if self.has_pipeline_auto_runs: + raise RunnerError( + "Please move all pipeline.run calls inside __main__ or remove them" + ) + + return module + + @contextmanager + def expect_no_pipeline_runs(self): # type: ignore[no-untyped-def] + """Monkey patch pipeline.run during module loading + Restore it once importing is done + """ + original_run = dlt.Pipeline.run + + def noop(*args, **kwargs) -> LoadInfo: # type: ignore + self.has_pipeline_auto_runs = True + + dlt.Pipeline.run = noop # type: ignore + + yield + + dlt.Pipeline.run = original_run # type: ignore + + @property + def pipeline_module(self) -> ModuleType: + self.module = self.load_module(self.params.script_path) + return self.module + + @property + def script_contents(self) -> str: + """Loads script contents""" + with open(self.params.script_path, encoding="utf-8") as fp: + return fp.read() + + @property + def module_name(self) -> str: + """Strips extension with path and returns filename as modulename""" + module_name = self.params.script_path.split(os.sep)[-1] + if module_name.endswith(".py"): + module_name = module_name[:-3] + + return module_name + + @property + def pipeline_members(self) -> PipelineMembers: + """Inspect the module and return pipelines with resources and sources + We populate sources, pipelines with relevant instances by their variable name. + + Resources and sources must be initialized and bound beforehand. 
+ """ + members: PipelineMembers = { + "pipelines": {}, + "sources": {}, + } + for name, value in self.pipeline_module.__dict__.items(): + # skip modules and private stuff + if isinstance(value, ModuleType) or name.startswith("_"): + continue + + if isinstance(value, dlt.Pipeline): + members["pipelines"][name] = value + + if isinstance(value, DltSource): + members["sources"][name] = value + + if isinstance(value, DltResource): + if value._args_bound: + members["sources"][name] = value + else: + fmt.echo(fmt.info_style(f"Resource: {value.name} is not bound, skipping.")) + + return members + + @property + def run_arguments(self) -> t.Dict[str, str]: + run_options: t.Dict[str, str] = {} + if not self.params.args: + return run_options + + for arg in self.params.args: + arg_name, value = arg.split("=") + run_options[arg_name] = value + + return run_options + + def validate_arguments(self) -> None: + """Validates and checks""" + supported_args = { + "destination", + "staging", + "credentials", + "table_name", + "write_disposition", + "dataset_name", + "primary_key", + "schema_contract", + "loader_file_format", + } + errors = [] + arguments = self.run_arguments + for arg_name, _ in arguments.items(): + if arg_name not in supported_args: + errors.append(f"Invalid argument {arg_name}") + + if ( + "write_disposition" in arguments + and arguments["write_disposition"] not in self.supported_write_disposition + ): + errors.append( + "Invalid write disposition, select one of" + f" {'|'.join(self.supported_write_disposition)}" + ) + if ( + "loader_file_format" in arguments + and arguments["loader_file_format"] not in self.supported_formats + ): + errors.append( + f"Invalid loader file format, select one of {'|'.join(self.supported_formats)}" + ) + if ( + "schema_contract" in arguments + and arguments["schema_contract"] not in self.supported_evolution_modes + ): + errors.append( + "Invalid schema_contract mode, select one of" + f" {'|'.join(self.supported_evolution_modes)}" + ) + + if errors: + all_errors = "\n".join(errors) + raise RunnerError(f"One or more arguments are invalid:\n{all_errors}") diff --git a/dlt/common/cli/runner/runner.py b/dlt/common/cli/runner/runner.py new file mode 100644 index 0000000000..dc99c41e58 --- /dev/null +++ b/dlt/common/cli/runner/runner.py @@ -0,0 +1,34 @@ +import typing as t + +from typing_extensions import Self +from types import TracebackType + +from dlt.common.pipeline import LoadInfo +from dlt.common.cli.runner.inquirer import Inquirer +from dlt.common.cli.runner.pipeline_script import PipelineScript +from dlt.common.cli.runner.types import RunnerParams + + +class PipelineRunner: + def __init__(self, params: RunnerParams) -> None: + self.params = params + self.script = PipelineScript(params) + self.script.validate_arguments() + self.params.pipeline_workdir = self.script.workdir + self.inquirer = Inquirer(self.params, self.script.pipeline_members) + + def __enter__(self) -> Self: + self.pipeline, self.resource = self.inquirer.maybe_ask() + return self + + def __exit__( + self, + exc_type: t.Optional[t.Type[BaseException]] = None, + exc_value: t.Optional[BaseException] = None, + traceback: t.Optional[TracebackType] = None, + ) -> None: + pass + + def run(self) -> LoadInfo: + load_info = self.pipeline.run(self.resource, **self.script.run_arguments) # type: ignore[arg-type] + return load_info diff --git a/dlt/common/cli/runner/types.py b/dlt/common/cli/runner/types.py new file mode 100644 index 0000000000..df008085a0 --- /dev/null +++ b/dlt/common/cli/runner/types.py @@ 
diff --git a/dlt/common/cli/runner/types.py b/dlt/common/cli/runner/types.py
new file mode 100644
index 0000000000..df008085a0
--- /dev/null
+++ b/dlt/common/cli/runner/types.py
@@ -0,0 +1,22 @@
+import typing as t
+
+from dataclasses import dataclass
+
+import dlt
+
+from dlt.sources import DltResource, DltSource
+
+
+@dataclass
+class RunnerParams:
+    script_path: str
+    current_dir: str
+    pipeline_workdir: t.Optional[str] = None
+    pipeline_name: t.Optional[str] = None
+    source_name: t.Optional[str] = None
+    args: t.Optional[t.List[str]] = None
+
+
+class PipelineMembers(t.TypedDict):
+    pipelines: t.Dict[str, dlt.Pipeline]
+    sources: t.Dict[str, t.Union[DltResource, DltSource]]
diff --git a/dlt/common/reflection/utils.py b/dlt/common/reflection/utils.py
index 9bd3cb6775..3ce8ab8ba6 100644
--- a/dlt/common/reflection/utils.py
+++ b/dlt/common/reflection/utils.py
@@ -78,8 +78,11 @@ def creates_func_def_name_node(func_def: ast.FunctionDef, source_lines: Sequence
 def rewrite_python_script(
     source_script_lines: List[str], transformed_nodes: List[Tuple[ast.AST, ast.AST]]
 ) -> List[str]:
-    """Replaces all the nodes present in `transformed_nodes` in the `script_lines`. The `transformed_nodes` is a tuple where the first element
-    is must be a node with full location information created out of `script_lines`"""
+    """Replaces all the nodes present in `transformed_nodes` in the `script_lines`.
+
+    The `transformed_nodes` is a tuple where the first element must be a node with
+    full location information created out of `script_lines`
+    """
     script_lines: List[str] = []
     last_line = -1
     last_offset = -1
diff --git a/dlt/reflection/script_visitor.py b/dlt/reflection/script_visitor.py
index 52b19fe031..90129395b1 100644
--- a/dlt/reflection/script_visitor.py
+++ b/dlt/reflection/script_visitor.py
@@ -2,7 +2,8 @@
 import ast
 import astunparse
 from ast import NodeVisitor
-from typing import Any, Dict, List
+from collections import defaultdict
+from typing import Any, Dict, List, Tuple
 
 from dlt.common.reflection.utils import find_outer_func_def
 
@@ -19,6 +20,7 @@ def __init__(self, source: str):
         # self.source_aliases: Dict[str, str] = {}
         self.is_destination_imported: bool = False
         self.known_calls: Dict[str, List[inspect.BoundArguments]] = {}
+        self.known_calls_with_nodes: Dict[str, List[Tuple[inspect.BoundArguments, ast.AST]]] = defaultdict(list)
         self.known_sources: Dict[str, ast.FunctionDef] = {}
         self.known_source_calls: Dict[str, List[ast.Call]] = {}
         self.known_resources: Dict[str, ast.FunctionDef] = {}
@@ -104,6 +106,7 @@ def visit_Call(self, node: ast.Call) -> Any:
             # print(f"ALIAS: {alias_name} of {self.func_aliases.get(alias_name)} with {bound_args}")
             fun_calls = self.known_calls.setdefault(fn, [])
             fun_calls.append(bound_args)
+            self.known_calls_with_nodes[fn].append((bound_args, node))
         except TypeError:
             # skip the signature
             pass
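
The new `known_calls_with_nodes` mapping keeps the AST node next to each bound-argument set, so callers can locate a call site in the source. A hedged sketch of consuming it through `parse_init_script`, the same entry point `PipelineScript` uses above; the script path and command name are illustrative, and the exact visitor API should be treated as an assumption:

```python
from dlt.cli.utils import parse_init_script

# parse_init_script is how PipelineScript obtains a visitor (see pipeline_script.py).
source = open("pipeline.py", encoding="utf-8").read()  # hypothetical script
visitor = parse_init_script("run", source, "pipeline")

for fn, calls in visitor.known_calls_with_nodes.items():
    for bound_args, node in calls:
        # each entry pairs inspect.BoundArguments with the ast.Call node,
        # so the exact call site (node.lineno) is recoverable from the source
        print(fn, dict(bound_args.arguments), node.lineno)
```
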
diff --git a/tests/cli/cases/cli_runner/.dlt/config.toml b/tests/cli/cases/cli_runner/.dlt/config.toml
new file mode 100644
index 0000000000..54a4d376fa
--- /dev/null
+++ b/tests/cli/cases/cli_runner/.dlt/config.toml
@@ -0,0 +1,4 @@
+restore_from_destination = false # block restoring from destination
+
+[runtime]
+log_level = "WARNING" # the system log level of dlt
diff --git a/tests/cli/cases/cli_runner/__init__.py b/tests/cli/cases/cli_runner/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/cli/cases/cli_runner/pipeline.py b/tests/cli/cases/cli_runner/pipeline.py
new file mode 100644
index 0000000000..5473744448
--- /dev/null
+++ b/tests/cli/cases/cli_runner/pipeline.py
@@ -0,0 +1,35 @@
+import dlt
+
+
+@dlt.resource
+def quads_resource():
+    for idx in range(10):
+        yield {"id": idx, "num": idx**4}
+
+
+@dlt.resource
+def numbers_resource():
+    for idx in range(10):
+        yield {"id": idx, "num": idx + 1}
+
+
+@dlt.destination(loader_file_format="parquet")
+def null_sink(_items, _table) -> None:
+    pass
+
+
+quads_resource_instance = quads_resource()
+numbers_resource_instance = numbers_resource()
+
+quads_pipeline = dlt.pipeline(
+    pipeline_name="numbers_quadruples_pipeline",
+    destination=null_sink,
+)
+
+numbers_pipeline = dlt.pipeline(
+    pipeline_name="my_numbers_pipeline",
+    destination="duckdb",
+)
+
+if __name__ == "__main__":
+    load_info = numbers_pipeline.run(numbers_resource(), schema=dlt.Schema("bobo-schema"))
diff --git a/tests/cli/cases/cli_runner/pipeline_with_immediate_run.py b/tests/cli/cases/cli_runner/pipeline_with_immediate_run.py
new file mode 100644
index 0000000000..50f3b5b6d3
--- /dev/null
+++ b/tests/cli/cases/cli_runner/pipeline_with_immediate_run.py
@@ -0,0 +1,21 @@
+import dlt
+
+
+@dlt.resource
+def numbers_resource():
+    for idx in range(100):
+        yield {"id": idx, "num": idx}
+
+
+@dlt.destination(loader_file_format="parquet")
+def null_sink(_items, _table) -> None:
+    pass
+
+
+numbers_pipeline = dlt.pipeline(
+    pipeline_name="numbers_pipeline",
+    destination=null_sink,
+)
+
+load_info = numbers_pipeline.run()
+print(load_info)
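
This fixture intentionally calls `pipeline.run()` at import time, which trips the runner's import guard and produces the "Please move all pipeline.run calls inside __main__ or remove them" error. The compliant pattern, as `pipeline.py` above demonstrates, gates the call:

```python
# Runnable variant of the fixture: keep definitions importable and put the
# actual run behind __main__, so `dlt run` can import the module safely.
if __name__ == "__main__":
    load_info = numbers_pipeline.run()
    print(load_info)
```
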
diff --git a/tests/cli/test_run_command.py b/tests/cli/test_run_command.py
new file mode 100644
index 0000000000..a379288655
--- /dev/null
+++ b/tests/cli/test_run_command.py
@@ -0,0 +1,182 @@
+import io
+import contextlib
+
+import shutil
+from typing import List
+from unittest import mock
+
+import dlt
+import pytest
+
+from dlt.cli import run_command
+
+from dlt.common.utils import set_working_dir
+from tests.utils import TEST_STORAGE_ROOT, TESTS_ROOT
+
+
+CLI_RUNNER_PIPELINES = TESTS_ROOT / "cli/cases/cli_runner"
+TEST_PIPELINE = CLI_RUNNER_PIPELINES / "pipeline.py"
+TEST_PIPELINE_WITH_IMMEDIATE_RUN = CLI_RUNNER_PIPELINES / "pipeline_with_immediate_run.py"
+
+
+def test_run_command_happy_path_works_as_expected():
+    # pipeline variable name
+    pipeline_name = "numbers_pipeline"
+    real_pipeline_name = "my_numbers_pipeline"
+    p = dlt.pipeline(pipeline_name=pipeline_name)
+    p._wipe_working_folder()
+    shutil.copytree(CLI_RUNNER_PIPELINES, TEST_STORAGE_ROOT, dirs_exist_ok=True)
+
+    with io.StringIO() as buf, contextlib.redirect_stdout(buf), set_working_dir(TEST_STORAGE_ROOT):
+        run_command.run_pipeline_command(
+            str(TEST_PIPELINE),
+            pipeline_name,
+            "numbers_resource_instance",
+            ["write_disposition=merge", "loader_file_format=jsonl"],
+        )
+
+        output = buf.getvalue()
+        assert "Current working directory is different from the pipeline script" in output
+        assert "Pipeline: numbers_pipeline (my_numbers_pipeline)" in output
+        assert "Resource: numbers_resource_instance (numbers_resource)" in output
+        assert "Pipeline my_numbers_pipeline load step completed" in output
+        assert "contains no failed jobs" in output
+
+    # Check if we can successfully attach to pipeline
+    pipeline = dlt.attach(real_pipeline_name)
+    assert pipeline.schema_names == ["my_numbers"]
+
+    trace = pipeline.last_trace
+    assert trace is not None
+    assert len(trace.steps) == 4
+
+    step = trace.steps[0]
+    assert step.step == "extract"
+
+    with pipeline.sql_client() as c:
+        with c.execute_query("select count(id) from numbers_resource") as cur:
+            row = list(cur.fetchall())[0]
+            assert row[0] == 10
+
+
+def test_run_command_fails_with_relevant_error_if_pipeline_resource_or_source_not_found():
+    with io.StringIO() as buf, contextlib.redirect_stdout(buf):
+        run_command.run_pipeline_command(
+            str(TEST_PIPELINE),
+            "pipeline_404",
+            "numbers_resource_instance",
+            ["write_disposition=merge", "loader_file_format=jsonl"],
+        )
+
+        output = buf.getvalue()
+        assert "Pipeline pipeline_404 has not been found in pipeline script" in output
+        assert "You can choose one of: quads_pipeline, numbers_pipeline" in output
+
+    with io.StringIO() as buf, contextlib.redirect_stdout(buf):
+        run_command.run_pipeline_command(
+            str(TEST_PIPELINE),
+            "numbers_pipeline",
+            "resource_404",
+            ["write_disposition=merge", "loader_file_format=jsonl"],
+        )
+
+        output = buf.getvalue()
+        assert (
+            "Source or resource with name: resource_404 has not been found in pipeline script."
+            in output
+        )
+        assert "You can choose one of: quads_resource_instance, numbers_resource_instance" in output
+
+
+def test_run_command_allows_selection_of_pipeline_source_or_resource():
+    with mock.patch(
+        "dlt.common.cli.runner.inquirer.Inquirer.ask", return_value=0
+    ) as mocked_ask, io.StringIO() as buf, contextlib.redirect_stdout(buf):
+        run_command.run_pipeline_command(
+            str(TEST_PIPELINE),
+            None,
+            None,
+            ["write_disposition=append", "loader_file_format=parquet"],
+        )
+
+        # expect 2 calls to Inquirer.ask
+        # first for pipeline selection
+        # second for resource or source
+        assert mocked_ask.call_count == 2
+
+
+def test_run_command_exits_if_exit_choice_selected():
+    with mock.patch(
+        "dlt.common.cli.runner.inquirer.fmt.prompt", return_value="n"
+    ), io.StringIO() as buf, contextlib.redirect_stdout(buf):
+        run_command.run_pipeline_command(
+            str(TEST_PIPELINE),
+            None,
+            None,
+            ["write_disposition=append", "loader_file_format=parquet"],
+        )
+
+        output = buf.getvalue()
+        assert "Stopping..." in output
+
+
+def test_run_command_exits_if_pipeline_run_calls_exist_at_the_top_level():
+    with io.StringIO() as buf, contextlib.redirect_stdout(buf):
+        run_command.run_pipeline_command(
+            str(TEST_PIPELINE_WITH_IMMEDIATE_RUN),
+            None,
+            None,
+            ["write_disposition=append", "loader_file_format=parquet"],
+        )
+        output = buf.getvalue()
+        assert "Please move all pipeline.run calls inside __main__ or remove them" in output
+
+
+@pytest.mark.parametrize(
+    "arguments,expected_error_message",
+    [
+        [
+            ["write_disposition=merge", "loader_file_format=markdown"],
+            (
+                "Invalid loader file format, select one of"
+                " jsonl|puae-jsonl|insert_values|sql|parquet|reference|arrow"
+            ),
+        ],
+        [
+            ["write_disposition=guess", "loader_file_format=jsonl"],
+            "Invalid write disposition, select one of skip|append|replace|merge",
+        ],
+        [
+            ["schema_contract=guess", "loader_file_format=jsonl"],
+            "Invalid schema_contract mode, select one of evolve|discard_value|freeze|discard_row",
+        ],
+        [
+            ["yolo=yes", "moon=yes"],
+            "Invalid argument yolo\nInvalid argument moon",
+        ],
+        # Good case
+        [["write_disposition=append", "loader_file_format=parquet"], ""],
+    ],
+)
+def test_run_command_with_invalid_pipeline_run_arguments(
+    arguments: List[str], expected_error_message: str
+):
+    pipeline_name = "numbers_pipeline"
+    p = dlt.pipeline(pipeline_name=pipeline_name)
+    p._wipe_working_folder()
+    shutil.copytree(CLI_RUNNER_PIPELINES, TEST_STORAGE_ROOT, dirs_exist_ok=True)
+
+    with io.StringIO() as buf, contextlib.redirect_stdout(buf), set_working_dir(TEST_STORAGE_ROOT):
+        run_command.run_pipeline_command(
+            str(TEST_PIPELINE),
+            pipeline_name,
+            "numbers_resource_instance",
+            arguments,
+        )
+
+        output = buf.getvalue()
+        if expected_error_message:
+            assert expected_error_message in output
+        else:
+            # Check if we can attach to pipeline
+            dlt.attach("my_numbers_pipeline")
diff --git a/tests/conftest.py b/tests/conftest.py
index ab22d6ca7a..3600227f3a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,7 @@
 import os
 import dataclasses
 import logging
+from pathlib import Path
 from typing import List
 
 # patch which providers to enable
diff --git a/tests/helpers/streamlit_tests/test_streamlit_show_resources.py b/tests/helpers/streamlit_tests/test_streamlit_show_resources.py
index dd807260fe..02d86e039e 100644
--- a/tests/helpers/streamlit_tests/test_streamlit_show_resources.py
+++ b/tests/helpers/streamlit_tests/test_streamlit_show_resources.py
@@ -8,6 +8,7 @@
 import os
 import sys
 from pathlib import Path
+from unittest import mock
 
 import pytest
 
@@ -152,13 +153,12 @@ def test_render_with_pipeline_with_different_pipeline_dirs():
     def dummy_render(pipeline: dlt.Pipeline) -> None:
         pass
 
-    old_args = sys.argv[:]
-    with pytest.raises(CannotRestorePipelineException):
-        sys.argv = [*base_args, "/run/dlt"]
+    with pytest.raises(CannotRestorePipelineException), mock.patch.object(
+        sys, "argv", [*base_args, "/run/dlt"]
+    ):
         render_with_pipeline(dummy_render)
 
-    with pytest.raises(CannotRestorePipelineException):
-        sys.argv = [*base_args, "/tmp/dlt"]
+    with pytest.raises(CannotRestorePipelineException), mock.patch.object(
+        sys, "argv", [*base_args, "/tmp/dlt"]
+    ):
         render_with_pipeline(dummy_render)
-
-    sys.argv = old_args
diff --git a/tests/utils.py b/tests/utils.py
index 00523486ea..aa5849cf72 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,5 +1,6 @@
 import multiprocessing
 import os
+from pathlib import Path
 import platform
 import sys
 from os import environ
@@ -29,6 +30,7 @@
 from dlt.common.pipeline import PipelineContext, SupportsPipeline
 
 TEST_STORAGE_ROOT = "_storage"
+TESTS_ROOT = Path(__file__).parent.absolute()
 
 # destination constants