From 531291338c284e3fdffcf1620c8d7222ddcc34bd Mon Sep 17 00:00:00 2001 From: skshetry <18718008+skshetry@users.noreply.github.com> Date: Tue, 23 Jul 2024 16:02:34 +0545 Subject: [PATCH] don't depend on datachain from PATH to exec processes (#118) --- src/datachain/cli.py | 18 +++++++----------- src/datachain/query/dataset.py | 6 +++--- src/datachain/utils.py | 6 ++++++ 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/datachain/cli.py b/src/datachain/cli.py index 531817737..f25eb9afe 100644 --- a/src/datachain/cli.py +++ b/src/datachain/cli.py @@ -3,7 +3,7 @@ import shlex import sys import traceback -from argparse import SUPPRESS, Action, ArgumentParser, ArgumentTypeError, Namespace +from argparse import Action, ArgumentParser, ArgumentTypeError, Namespace from collections.abc import Iterable, Iterator, Mapping, Sequence from importlib.metadata import PackageNotFoundError, version from itertools import chain @@ -106,10 +106,7 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915 parser = ArgumentParser( description="DataChain: Wrangle unstructured AI data at scale", prog="datachain" ) - parser.add_argument("-V", "--version", action="version", version=__version__) - parser.add_argument("--internal-run-udf", action="store_true", help=SUPPRESS) - parser.add_argument("--internal-run-udf-worker", action="store_true", help=SUPPRESS) parent_parser = ArgumentParser(add_help=False) parent_parser.add_argument( @@ -155,6 +152,7 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915 metavar="command", dest="command", help=f"Use `{parser.prog} command --help` for command-specific help.", + required=True, ) parse_cp = subp.add_parser( "cp", parents=[parent_parser], description="Copy data files from the cloud" @@ -556,6 +554,8 @@ def get_parser() -> ArgumentParser: # noqa: PLR0915 "gc", parents=[parent_parser], description="Garbage collect temporary tables" ) + subp.add_parser("internal-run-udf", parents=[parent_parser]) + subp.add_parser("internal-run-udf-worker", parents=[parent_parser]) add_completion_parser(subp, [parent_parser]) return parser @@ -910,27 +910,23 @@ def completion(shell: str) -> str: ) -def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0911, PLR0912, PLR0915 +def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0912, PLR0915 # Required for Windows multiprocessing support freeze_support() parser = get_parser() args = parser.parse_args(argv) - if args.internal_run_udf: + if args.command == "internal-run-udf": from datachain.query.dispatch import udf_entrypoint return udf_entrypoint() - if args.internal_run_udf_worker: + if args.command == "internal-run-udf-worker": from datachain.query.dispatch import udf_worker_entrypoint return udf_worker_entrypoint() - if args.command is None: - parser.print_help() - return 1 - from .catalog import get_catalog logger.addHandler(logging.StreamHandler()) diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index be1e1995d..f32f71332 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -54,6 +54,7 @@ batched, determine_processes, filtered_cloudpickle_dumps, + get_datachain_executable, ) from .metrics import metrics @@ -507,13 +508,12 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: # Run the UDFDispatcher in another process to avoid needing # if __name__ == '__main__': in user scripts - datachain_exec_path = os.environ.get("DATACHAIN_EXEC_PATH", "datachain") - + exec_cmd = get_datachain_executable() envs = dict(os.environ) envs.update({"PYTHONPATH": os.getcwd()}) process_data = filtered_cloudpickle_dumps(udf_info) result = subprocess.run( # noqa: S603 - [datachain_exec_path, "--internal-run-udf"], + [*exec_cmd, "internal-run-udf"], input=process_data, check=False, env=envs, diff --git a/src/datachain/utils.py b/src/datachain/utils.py index a2fc3e0c0..6f4a55ee2 100644 --- a/src/datachain/utils.py +++ b/src/datachain/utils.py @@ -427,3 +427,9 @@ def filtered_cloudpickle_dumps(obj: Any) -> bytes: for model_class, namespace in model_namespaces.items(): # Restore original __pydantic_parent_namespace__ locally. model_class.__pydantic_parent_namespace__ = namespace + + +def get_datachain_executable() -> list[str]: + if datachain_exec_path := os.getenv("DATACHAIN_EXEC_PATH"): + return [datachain_exec_path] + return [sys.executable, "-m", "datachain"]