From a8a6edeebcbcab5080eaf6aa4e036bce27067080 Mon Sep 17 00:00:00 2001 From: novahow Date: Thu, 21 Mar 2024 05:16:34 +0800 Subject: [PATCH] remove unused params and add test Signed-off-by: novahow --- flytekit/clis/sdk_in_container/run.py | 116 +++++++++++------- flytekit/clis/sdk_in_container/versions.py | 92 ++++++++------ .../unit/cli/pyflyte/test_versions.py | 47 +++++++ 3 files changed, 170 insertions(+), 85 deletions(-) create mode 100644 tests/flytekit/unit/cli/pyflyte/test_versions.py diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 82d47fbfc3c..d2ebfa7b5f5 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -60,13 +60,57 @@ class RunLevelComputedParams: @dataclass -class RunLevelParams(PyFlyteParams): +class RunBaseParams(PyFlyteParams): """ - This class is used to store the parameters that are used to run a workflow / task / launchplan. + This task contains basic parameters used in pyflyte run and pyflyte show-versions """ project: str = make_click_option_field(project_option) domain: str = make_click_option_field(domain_option) + limit: int = make_click_option_field( + click.Option( + param_decls=["--limit", "limit"], + required=False, + type=int, + default=50, + hidden=True, + show_default=True, + help="Use this to limit number of entities to fetch", + ) + ) + _remote: typing.Optional[FlyteRemote] = None + remote: bool = field(default=False, init=False) + + def remote_instance(self) -> FlyteRemote: + if self._remote is None: + data_upload_location = None + if self.is_remote: + data_upload_location = remote_fs.REMOTE_PLACEHOLDER + self._remote = get_plugin().get_remote(self.config_file, self.project, self.domain, data_upload_location) + return self._remote + + @property + def is_remote(self) -> bool: + return self.remote + + @classmethod + def from_dict(cls, d: typing.Dict[str, typing.Any]) -> "RunLevelParams": + return cls(**d) + + @classmethod + def options(cls) -> typing.List[click.Option]: + """ + Return the set of base parameters added to every pyflyte run workflow subcommand. + """ + return [get_option_from_metadata(f.metadata) for f in fields(cls) if f.metadata] + + +@dataclass +class RunLevelParams(RunBaseParams): + """ + This class is used to store the parameters that are used to run a workflow / task / launchplan. + """ + destination_dir: str = make_click_option_field( click.Option( param_decls=["--destination-dir", "destination_dir"], @@ -238,17 +282,6 @@ class RunLevelParams(PyFlyteParams): help="Whether to register and run the workflow on a Flyte deployment", ) ) - limit: int = make_click_option_field( - click.Option( - param_decls=["--limit", "limit"], - required=False, - type=int, - default=50, - hidden=True, - show_default=True, - help="Use this to limit number of entities to fetch", - ) - ) cluster_pool: str = make_click_option_field( click.Option( param_decls=["--cluster-pool", "cluster_pool"], @@ -259,30 +292,6 @@ class RunLevelParams(PyFlyteParams): ) ) computed_params: RunLevelComputedParams = field(default_factory=RunLevelComputedParams) - _remote: typing.Optional[FlyteRemote] = None - - def remote_instance(self) -> FlyteRemote: - if self._remote is None: - data_upload_location = None - if self.is_remote: - data_upload_location = remote_fs.REMOTE_PLACEHOLDER - self._remote = get_plugin().get_remote(self.config_file, self.project, self.domain, data_upload_location) - return self._remote - - @property - def is_remote(self) -> bool: - return self.remote - - @classmethod - def from_dict(cls, d: typing.Dict[str, typing.Any]) -> "RunLevelParams": - return cls(**d) - - @classmethod - def options(cls) -> typing.List[click.Option]: - """ - Return the set of base parameters added to every pyflyte run workflow subcommand. - """ - return [get_option_from_metadata(f.metadata) for f in fields(cls) if f.metadata] def load_naive_entity(module_name: str, entity_name: str, project_root: str) -> typing.Union[WorkflowBase, PythonTask]: @@ -577,7 +586,7 @@ def __init__(self, name: str, h: str, entity_name: str, launcher: str, **kwargs) self._entity = None def _looped_fetch_entity( - self, entity_fetch_func: typing.Callable, run_level_params: RunLevelParams + self, entity_fetch_func: typing.Callable, run_level_params: RunBaseParams ) -> typing.Union[FlyteLaunchPlan, FlyteTask]: version_splits = self._entity_name.split(RemoteVersion.splitter) for _version_seg_len in range(len(version_splits)): @@ -599,8 +608,8 @@ def _looped_fetch_entity( def _fetch_entity(self, ctx: click.Context) -> typing.Union[FlyteLaunchPlan, FlyteTask]: if self._entity: return self._entity - run_level_params: RunLevelParams = ctx.obj - r = run_level_params.remote_instance() + run_level_params: RunBaseParams = ctx.obj + r: FlyteRemote = run_level_params.remote_instance() if self._launcher == self.LP_LAUNCHER: entity = self._looped_fetch_entity(r.fetch_launch_plan, run_level_params) else: @@ -679,10 +688,10 @@ class RemoteEntityGroup(click.RichGroup): WORKFLOW_COMMAND = "remote-workflow" TASK_COMMAND = "remote-task" - def __init__(self, command_name: str): + def __init__(self, command_name: str, h: str): super().__init__( name=command_name, - help=f"Retrieve {command_name} from a remote flyte instance and execute them.", + help=h, params=[ click.Option( ["--limit", "limit"], @@ -855,7 +864,7 @@ class RunCommand(click.RichGroup): A click command group for registering and executing flyte workflows & tasks in a file. """ - _run_params: typing.Type[RunLevelParams] = RunLevelParams + _run_params: typing.Type[RunBaseParams] = RunLevelParams def __init__(self, *args, **kwargs): if "params" not in kwargs: @@ -882,16 +891,29 @@ def get_command(self, ctx, filename): ctx.obj = {} if not isinstance(ctx.obj, self._run_params): params = {} - # NOTE: ctx.params: RunLevelParams params.update(ctx.params) params.update(ctx.obj) ctx.obj = self._run_params.from_dict(params) + entity_group_help_msg = ( + "Retrieve {command_name} from a remote flyte instance and execute them.\n" + "You may attach a version behind the {command_name} name to execute a specific version, \n" + "e.g. {command_name}:version1" + ) if filename == RemoteEntityGroup.LAUNCHPLAN_COMMAND: - return RemoteEntityGroup(RemoteEntityGroup.LAUNCHPLAN_COMMAND) + return RemoteEntityGroup( + RemoteEntityGroup.LAUNCHPLAN_COMMAND, + entity_group_help_msg.format(command_name=RemoteEntityGroup.LAUNCHPLAN_COMMAND), + ) elif filename == RemoteEntityGroup.WORKFLOW_COMMAND: - return RemoteEntityGroup(RemoteEntityGroup.WORKFLOW_COMMAND) + return RemoteEntityGroup( + RemoteEntityGroup.WORKFLOW_COMMAND, + entity_group_help_msg.format(command_name=RemoteEntityGroup.WORKFLOW_COMMAND), + ) elif filename == RemoteEntityGroup.TASK_COMMAND: - return RemoteEntityGroup(RemoteEntityGroup.TASK_COMMAND) + return RemoteEntityGroup( + RemoteEntityGroup.TASK_COMMAND, + entity_group_help_msg.format(command_name=RemoteEntityGroup.TASK_COMMAND), + ) return WorkflowCommand(filename, name=filename, help=f"Run a [workflow|task] from {filename}") diff --git a/flytekit/clis/sdk_in_container/versions.py b/flytekit/clis/sdk_in_container/versions.py index a78048f6a99..ed4ce5fd799 100644 --- a/flytekit/clis/sdk_in_container/versions.py +++ b/flytekit/clis/sdk_in_container/versions.py @@ -1,31 +1,47 @@ import typing +from dataclasses import dataclass import rich_click as click from click import Context, Parameter -from flytekit.clis.sdk_in_container.run import DynamicEntityLaunchCommand, RemoteEntityGroup, RunCommand, RunLevelParams +from flytekit.clis.sdk_in_container.run import DynamicEntityLaunchCommand, RemoteEntityGroup, RunBaseParams, RunCommand from flytekit.models.admin.common import Sort from flytekit.models.common import NamedEntityIdentifier -from flytekit.remote import FlyteLaunchPlan, FlyteRemote, FlyteTask, FlyteWorkflow +from flytekit.remote import FlyteLaunchPlan, FlyteRemote, FlyteTask + + +@dataclass +class VersionLevelParams(RunBaseParams): + """ + This class is used to store the parameters for the version command. + """ + + pass class InstanceDisplayCommand(click.RichCommand): + """ + Dummy command that displays the version of the entity. + """ + def __init__(self, name, h, **kwargs): super().__init__(name=name, help=h, **kwargs) class DynamicEntityVersionCommand(click.RichGroup, DynamicEntityLaunchCommand): + """ + Command that retrieves the versions of a remote entity. + """ + def __init__(self, name: str, h: str, entity_name: str, launcher: str, **kwargs): - super(click.RichGroup, self).__init__(name, h, entity_name, launcher, **kwargs) + DynamicEntityLaunchCommand.__init__(self, name, h, entity_name, launcher, **kwargs) def get_params(self, ctx: Context) -> typing.List[Parameter]: - """ - returns empty list to avoid parent adding task/workflow/launchplan params - """ - return [] + # we don't use super.get_params here, because DynamicEntityLaunchCommand.get_params adds the options of the entity + return click.RichGroup.get_params(self, ctx) def list_commands(self, ctx: click.Context): - run_params: RunLevelParams = ctx.obj + run_params: VersionLevelParams = ctx.obj named_entity = NamedEntityIdentifier(run_params.project, run_params.domain, ctx.info_name) _remote_instance: FlyteRemote = run_params.remote_instance() entity = self._fetch_entity(ctx) @@ -37,36 +53,33 @@ def list_commands(self, ctx: click.Context): sorted_entities, _ = _remote_instance.client.list_launch_plans_paginated( named_entity, sort_by=Sort("created_at", Sort.Direction.DESCENDING) ) - elif isinstance(entity, FlyteWorkflow): - sorted_entities, _ = _remote_instance.client.list_workflows_paginated( - named_entity, sort_by=Sort("created_at", Sort.Direction.DESCENDING) - ) else: raise ValueError(f"Unknown entity type {type(entity)}") - self._entity_dict = { - _entity.id.version: _entity.closure.created_at.strftime("%Y-%m-%d %H:%M:%S") for _entity in sorted_entities - } + parse_creation_time = ( + lambda x: x.closure.created_at.strftime("%Y-%m-%d %H:%M:%S") + if x.closure.created_at is not None + else "Unknown Time" + ) + self._entity_dict = {_entity.id.version: parse_creation_time(_entity) for _entity in sorted_entities} return self._entity_dict.keys() def get_command(self, ctx, version): - """ - returns version as command and created_at as help - """ if ctx.obj is None: ctx.obj = {} return InstanceDisplayCommand(name=version, h=f"Created At {self._entity_dict[version]}") + def invoke(self, ctx: Context) -> typing.Any: + pass + class RemoteEntityVersionGroup(RemoteEntityGroup): """ - click multicommand that retrieves launchplans from a remote flyte instance and display version of them. + click multicommand that retrieves launchplans/tasks from a remote flyte instance and display version of them. """ - def __init__(self, command_name: str): - super().__init__( - command_name, - ) + def __init__(self, command_name: str, h: str): + super().__init__(command_name, h) def get_command(self, ctx: click.Context, name: str): if self._command_name in [self.LAUNCHPLAN_COMMAND, self.WORKFLOW_COMMAND]: @@ -85,37 +98,40 @@ def get_command(self, ctx: click.Context, name: str): class VersionCommand(RunCommand): - _run_params: typing.Type[RunLevelParams] = RunLevelParams + _run_params: typing.Type[RunBaseParams] = VersionLevelParams def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self._files = [] - def list_commands(self, ctx: click.Context, add_remote: bool = True): - self._files = sorted(self._files) - if add_remote: - self._files = self._files + [ - RemoteEntityGroup.LAUNCHPLAN_COMMAND, - RemoteEntityGroup.WORKFLOW_COMMAND, - RemoteEntityGroup.TASK_COMMAND, - ] + def list_commands(self, ctx: click.Context): + self._files = self._files + [ + RemoteEntityGroup.LAUNCHPLAN_COMMAND, + RemoteEntityGroup.WORKFLOW_COMMAND, + RemoteEntityGroup.TASK_COMMAND, + ] return self._files def get_command(self, ctx: click.Context, filename: str): + # call parent get_command to setup run_params super().get_command(ctx, filename) + entity_version_help = f"Show the versions of the specified {filename}." if filename == RemoteEntityGroup.LAUNCHPLAN_COMMAND: - return RemoteEntityVersionGroup(RemoteEntityGroup.LAUNCHPLAN_COMMAND) + return RemoteEntityVersionGroup(RemoteEntityGroup.LAUNCHPLAN_COMMAND, entity_version_help) elif filename == RemoteEntityGroup.WORKFLOW_COMMAND: - return RemoteEntityVersionGroup(RemoteEntityGroup.WORKFLOW_COMMAND) + return RemoteEntityVersionGroup(RemoteEntityGroup.WORKFLOW_COMMAND, entity_version_help) elif filename == RemoteEntityGroup.TASK_COMMAND: - return RemoteEntityVersionGroup(RemoteEntityGroup.TASK_COMMAND) + return RemoteEntityVersionGroup(RemoteEntityGroup.TASK_COMMAND, entity_version_help) else: raise NotImplementedError(f"File {filename} not found") -_run_help = """ -Show the versions of the entity. +_version_help = """ +Show the versions of the specified ``remote-task``, ``remote-launchplan``, or ``remote-workflow``. +Usage resembles the ``pyflyte run`` command, but instead of running the task, launchplan, or workflow, +it will display the versions of the remote entities and the time they were created. """ version = VersionCommand( name="show-versions", - help=_run_help, + help=_version_help, ) diff --git a/tests/flytekit/unit/cli/pyflyte/test_versions.py b/tests/flytekit/unit/cli/pyflyte/test_versions.py new file mode 100644 index 00000000000..6d7a7eab373 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/test_versions.py @@ -0,0 +1,47 @@ +import datetime +import os + +import mock +import pytest +from click.testing import CliRunner + +from flytekit.clis.sdk_in_container import pyflyte +from flytekit.models import task as _task +from flytekit.models.core.identifier import Identifier as _identifier +from flytekit.models.core.identifier import ResourceType as _resource_type +from flytekit.remote import FlyteTask + +pytest.importorskip("pandas") + +WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py") +REMOTE_WORKFLOW_FILE = "https://raw.githubusercontent.com/flyteorg/flytesnacks/8337b64b33df046b2f6e4cba03c74b7bdc0c4fb1/cookbook/core/flyte_basics/basic_workflow.py" +IMPERATIVE_WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "imperative_wf.py") +DIR_NAME = os.path.dirname(os.path.realpath(__file__)) + + +@mock.patch("flytekit.clis.sdk_in_container.versions.DynamicEntityVersionCommand._fetch_entity") +def test_pyflyte_version(mock_entity): + runner = CliRunner() + mock_entity.return_value = mock.MagicMock(spec=FlyteTask) + + created_at = datetime.datetime(2021, 1, 1) + mock_closure = _task.TaskClosure(mock.MagicMock(spec=_task.CompiledTask), created_at=created_at) + mock_tasks = [ + _task.Task(id=_identifier(_resource_type.TASK, "p1", "d1", "my_task", "my_version"), closure=mock_closure) + ] + + with mock.patch("flytekit.clients.friendly.SynchronousFlyteClient.list_tasks_paginated") as mock_list_tasks: + mock_list_tasks.return_value = (mock_tasks, None) + result = runner.invoke(pyflyte.main, ["show-versions", "remote-task", "any_task"], catch_exceptions=False) + + assert "my_version" in result.output + assert created_at.strftime("%Y-%m-%d %H:%M:%S") in result.output + assert result.exit_code == 0 + + +def test_pyflyte_version_no_workflows(): + with mock.patch("flytekit.configuration.plugin.FlyteRemote"): + runner = CliRunner() + result = runner.invoke(pyflyte.main, ["show-versions", "remote-workflow"], catch_exceptions=False) + + assert result.exit_code == 0