Skip to content

Commit

Permalink
remove unused params and add test
Browse files Browse the repository at this point in the history
Signed-off-by: novahow <[email protected]>
  • Loading branch information
novahow committed Mar 27, 2024
1 parent 6800b28 commit a8a6ede
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 85 deletions.
116 changes: 69 additions & 47 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,57 @@ class RunLevelComputedParams:


@dataclass
class RunLevelParams(PyFlyteParams):
class RunBaseParams(PyFlyteParams):
"""
This class is used to store the parameters that are used to run a workflow / task / launchplan.
This task contains basic parameters used in pyflyte run and pyflyte show-versions
"""

project: str = make_click_option_field(project_option)
domain: str = make_click_option_field(domain_option)
limit: int = make_click_option_field(
click.Option(
param_decls=["--limit", "limit"],
required=False,
type=int,
default=50,
hidden=True,
show_default=True,
help="Use this to limit number of entities to fetch",
)
)
_remote: typing.Optional[FlyteRemote] = None
remote: bool = field(default=False, init=False)

def remote_instance(self) -> FlyteRemote:
if self._remote is None:
data_upload_location = None
if self.is_remote:
data_upload_location = remote_fs.REMOTE_PLACEHOLDER
self._remote = get_plugin().get_remote(self.config_file, self.project, self.domain, data_upload_location)
return self._remote

@property
def is_remote(self) -> bool:
return self.remote

@classmethod
def from_dict(cls, d: typing.Dict[str, typing.Any]) -> "RunLevelParams":
return cls(**d)

@classmethod
def options(cls) -> typing.List[click.Option]:
"""
Return the set of base parameters added to every pyflyte run workflow subcommand.
"""
return [get_option_from_metadata(f.metadata) for f in fields(cls) if f.metadata]


@dataclass
class RunLevelParams(RunBaseParams):
"""
This class is used to store the parameters that are used to run a workflow / task / launchplan.
"""

destination_dir: str = make_click_option_field(
click.Option(
param_decls=["--destination-dir", "destination_dir"],
Expand Down Expand Up @@ -238,17 +282,6 @@ class RunLevelParams(PyFlyteParams):
help="Whether to register and run the workflow on a Flyte deployment",
)
)
limit: int = make_click_option_field(
click.Option(
param_decls=["--limit", "limit"],
required=False,
type=int,
default=50,
hidden=True,
show_default=True,
help="Use this to limit number of entities to fetch",
)
)
cluster_pool: str = make_click_option_field(
click.Option(
param_decls=["--cluster-pool", "cluster_pool"],
Expand All @@ -259,30 +292,6 @@ class RunLevelParams(PyFlyteParams):
)
)
computed_params: RunLevelComputedParams = field(default_factory=RunLevelComputedParams)
_remote: typing.Optional[FlyteRemote] = None

def remote_instance(self) -> FlyteRemote:
if self._remote is None:
data_upload_location = None
if self.is_remote:
data_upload_location = remote_fs.REMOTE_PLACEHOLDER
self._remote = get_plugin().get_remote(self.config_file, self.project, self.domain, data_upload_location)
return self._remote

@property
def is_remote(self) -> bool:
return self.remote

@classmethod
def from_dict(cls, d: typing.Dict[str, typing.Any]) -> "RunLevelParams":
return cls(**d)

@classmethod
def options(cls) -> typing.List[click.Option]:
"""
Return the set of base parameters added to every pyflyte run workflow subcommand.
"""
return [get_option_from_metadata(f.metadata) for f in fields(cls) if f.metadata]


def load_naive_entity(module_name: str, entity_name: str, project_root: str) -> typing.Union[WorkflowBase, PythonTask]:
Expand Down Expand Up @@ -577,7 +586,7 @@ def __init__(self, name: str, h: str, entity_name: str, launcher: str, **kwargs)
self._entity = None

def _looped_fetch_entity(
self, entity_fetch_func: typing.Callable, run_level_params: RunLevelParams
self, entity_fetch_func: typing.Callable, run_level_params: RunBaseParams
) -> typing.Union[FlyteLaunchPlan, FlyteTask]:
version_splits = self._entity_name.split(RemoteVersion.splitter)
for _version_seg_len in range(len(version_splits)):
Expand All @@ -599,8 +608,8 @@ def _looped_fetch_entity(
def _fetch_entity(self, ctx: click.Context) -> typing.Union[FlyteLaunchPlan, FlyteTask]:
if self._entity:
return self._entity
run_level_params: RunLevelParams = ctx.obj
r = run_level_params.remote_instance()
run_level_params: RunBaseParams = ctx.obj
r: FlyteRemote = run_level_params.remote_instance()
if self._launcher == self.LP_LAUNCHER:
entity = self._looped_fetch_entity(r.fetch_launch_plan, run_level_params)
else:
Expand Down Expand Up @@ -679,10 +688,10 @@ class RemoteEntityGroup(click.RichGroup):
WORKFLOW_COMMAND = "remote-workflow"
TASK_COMMAND = "remote-task"

def __init__(self, command_name: str):
def __init__(self, command_name: str, h: str):
super().__init__(
name=command_name,
help=f"Retrieve {command_name} from a remote flyte instance and execute them.",
help=h,
params=[
click.Option(
["--limit", "limit"],
Expand Down Expand Up @@ -855,7 +864,7 @@ class RunCommand(click.RichGroup):
A click command group for registering and executing flyte workflows & tasks in a file.
"""

_run_params: typing.Type[RunLevelParams] = RunLevelParams
_run_params: typing.Type[RunBaseParams] = RunLevelParams

def __init__(self, *args, **kwargs):
if "params" not in kwargs:
Expand All @@ -882,16 +891,29 @@ def get_command(self, ctx, filename):
ctx.obj = {}
if not isinstance(ctx.obj, self._run_params):
params = {}
# NOTE: ctx.params: RunLevelParams
params.update(ctx.params)
params.update(ctx.obj)
ctx.obj = self._run_params.from_dict(params)
entity_group_help_msg = (
"Retrieve {command_name} from a remote flyte instance and execute them.\n"
"You may attach a version behind the {command_name} name to execute a specific version, \n"
"e.g. {command_name}:version1"
)
if filename == RemoteEntityGroup.LAUNCHPLAN_COMMAND:
return RemoteEntityGroup(RemoteEntityGroup.LAUNCHPLAN_COMMAND)
return RemoteEntityGroup(
RemoteEntityGroup.LAUNCHPLAN_COMMAND,
entity_group_help_msg.format(command_name=RemoteEntityGroup.LAUNCHPLAN_COMMAND),
)
elif filename == RemoteEntityGroup.WORKFLOW_COMMAND:
return RemoteEntityGroup(RemoteEntityGroup.WORKFLOW_COMMAND)
return RemoteEntityGroup(
RemoteEntityGroup.WORKFLOW_COMMAND,
entity_group_help_msg.format(command_name=RemoteEntityGroup.WORKFLOW_COMMAND),
)
elif filename == RemoteEntityGroup.TASK_COMMAND:
return RemoteEntityGroup(RemoteEntityGroup.TASK_COMMAND)
return RemoteEntityGroup(
RemoteEntityGroup.TASK_COMMAND,
entity_group_help_msg.format(command_name=RemoteEntityGroup.TASK_COMMAND),
)
return WorkflowCommand(filename, name=filename, help=f"Run a [workflow|task] from {filename}")


Expand Down
92 changes: 54 additions & 38 deletions flytekit/clis/sdk_in_container/versions.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,47 @@
import typing
from dataclasses import dataclass

import rich_click as click
from click import Context, Parameter

from flytekit.clis.sdk_in_container.run import DynamicEntityLaunchCommand, RemoteEntityGroup, RunCommand, RunLevelParams
from flytekit.clis.sdk_in_container.run import DynamicEntityLaunchCommand, RemoteEntityGroup, RunBaseParams, RunCommand
from flytekit.models.admin.common import Sort
from flytekit.models.common import NamedEntityIdentifier
from flytekit.remote import FlyteLaunchPlan, FlyteRemote, FlyteTask, FlyteWorkflow
from flytekit.remote import FlyteLaunchPlan, FlyteRemote, FlyteTask


@dataclass
class VersionLevelParams(RunBaseParams):
"""
This class is used to store the parameters for the version command.
"""

pass


class InstanceDisplayCommand(click.RichCommand):
"""
Dummy command that displays the version of the entity.
"""

def __init__(self, name, h, **kwargs):
super().__init__(name=name, help=h, **kwargs)


class DynamicEntityVersionCommand(click.RichGroup, DynamicEntityLaunchCommand):
"""
Command that retrieves the versions of a remote entity.
"""

def __init__(self, name: str, h: str, entity_name: str, launcher: str, **kwargs):
super(click.RichGroup, self).__init__(name, h, entity_name, launcher, **kwargs)
DynamicEntityLaunchCommand.__init__(self, name, h, entity_name, launcher, **kwargs)

def get_params(self, ctx: Context) -> typing.List[Parameter]:
"""
returns empty list to avoid parent adding task/workflow/launchplan params
"""
return []
# we don't use super.get_params here, because DynamicEntityLaunchCommand.get_params adds the options of the entity
return click.RichGroup.get_params(self, ctx)

def list_commands(self, ctx: click.Context):
run_params: RunLevelParams = ctx.obj
run_params: VersionLevelParams = ctx.obj
named_entity = NamedEntityIdentifier(run_params.project, run_params.domain, ctx.info_name)
_remote_instance: FlyteRemote = run_params.remote_instance()
entity = self._fetch_entity(ctx)
Expand All @@ -37,36 +53,33 @@ def list_commands(self, ctx: click.Context):
sorted_entities, _ = _remote_instance.client.list_launch_plans_paginated(
named_entity, sort_by=Sort("created_at", Sort.Direction.DESCENDING)
)
elif isinstance(entity, FlyteWorkflow):
sorted_entities, _ = _remote_instance.client.list_workflows_paginated(
named_entity, sort_by=Sort("created_at", Sort.Direction.DESCENDING)
)
else:
raise ValueError(f"Unknown entity type {type(entity)}")

self._entity_dict = {
_entity.id.version: _entity.closure.created_at.strftime("%Y-%m-%d %H:%M:%S") for _entity in sorted_entities
}
parse_creation_time = (
lambda x: x.closure.created_at.strftime("%Y-%m-%d %H:%M:%S")
if x.closure.created_at is not None
else "Unknown Time"
)
self._entity_dict = {_entity.id.version: parse_creation_time(_entity) for _entity in sorted_entities}
return self._entity_dict.keys()

def get_command(self, ctx, version):
"""
returns version as command and created_at as help
"""
if ctx.obj is None:
ctx.obj = {}
return InstanceDisplayCommand(name=version, h=f"Created At {self._entity_dict[version]}")

def invoke(self, ctx: Context) -> typing.Any:
pass


class RemoteEntityVersionGroup(RemoteEntityGroup):
"""
click multicommand that retrieves launchplans from a remote flyte instance and display version of them.
click multicommand that retrieves launchplans/tasks from a remote flyte instance and display version of them.
"""

def __init__(self, command_name: str):
super().__init__(
command_name,
)
def __init__(self, command_name: str, h: str):
super().__init__(command_name, h)

def get_command(self, ctx: click.Context, name: str):
if self._command_name in [self.LAUNCHPLAN_COMMAND, self.WORKFLOW_COMMAND]:
Expand All @@ -85,37 +98,40 @@ def get_command(self, ctx: click.Context, name: str):


class VersionCommand(RunCommand):
_run_params: typing.Type[RunLevelParams] = RunLevelParams
_run_params: typing.Type[RunBaseParams] = VersionLevelParams

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._files = []

def list_commands(self, ctx: click.Context, add_remote: bool = True):
self._files = sorted(self._files)
if add_remote:
self._files = self._files + [
RemoteEntityGroup.LAUNCHPLAN_COMMAND,
RemoteEntityGroup.WORKFLOW_COMMAND,
RemoteEntityGroup.TASK_COMMAND,
]
def list_commands(self, ctx: click.Context):
self._files = self._files + [
RemoteEntityGroup.LAUNCHPLAN_COMMAND,
RemoteEntityGroup.WORKFLOW_COMMAND,
RemoteEntityGroup.TASK_COMMAND,
]
return self._files

def get_command(self, ctx: click.Context, filename: str):
# call parent get_command to setup run_params
super().get_command(ctx, filename)
entity_version_help = f"Show the versions of the specified {filename}."
if filename == RemoteEntityGroup.LAUNCHPLAN_COMMAND:
return RemoteEntityVersionGroup(RemoteEntityGroup.LAUNCHPLAN_COMMAND)
return RemoteEntityVersionGroup(RemoteEntityGroup.LAUNCHPLAN_COMMAND, entity_version_help)
elif filename == RemoteEntityGroup.WORKFLOW_COMMAND:
return RemoteEntityVersionGroup(RemoteEntityGroup.WORKFLOW_COMMAND)
return RemoteEntityVersionGroup(RemoteEntityGroup.WORKFLOW_COMMAND, entity_version_help)
elif filename == RemoteEntityGroup.TASK_COMMAND:
return RemoteEntityVersionGroup(RemoteEntityGroup.TASK_COMMAND)
return RemoteEntityVersionGroup(RemoteEntityGroup.TASK_COMMAND, entity_version_help)
else:
raise NotImplementedError(f"File {filename} not found")


_run_help = """
Show the versions of the entity.
_version_help = """
Show the versions of the specified ``remote-task``, ``remote-launchplan``, or ``remote-workflow``.
Usage resembles the ``pyflyte run`` command, but instead of running the task, launchplan, or workflow,
it will display the versions of the remote entities and the time they were created.
"""
version = VersionCommand(
name="show-versions",
help=_run_help,
help=_version_help,
)
47 changes: 47 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_versions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import datetime
import os

import mock
import pytest
from click.testing import CliRunner

from flytekit.clis.sdk_in_container import pyflyte
from flytekit.models import task as _task
from flytekit.models.core.identifier import Identifier as _identifier
from flytekit.models.core.identifier import ResourceType as _resource_type
from flytekit.remote import FlyteTask

pytest.importorskip("pandas")

WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py")
REMOTE_WORKFLOW_FILE = "https://raw.githubusercontent.com/flyteorg/flytesnacks/8337b64b33df046b2f6e4cba03c74b7bdc0c4fb1/cookbook/core/flyte_basics/basic_workflow.py"
IMPERATIVE_WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "imperative_wf.py")
DIR_NAME = os.path.dirname(os.path.realpath(__file__))


@mock.patch("flytekit.clis.sdk_in_container.versions.DynamicEntityVersionCommand._fetch_entity")
def test_pyflyte_version(mock_entity):
runner = CliRunner()
mock_entity.return_value = mock.MagicMock(spec=FlyteTask)

created_at = datetime.datetime(2021, 1, 1)
mock_closure = _task.TaskClosure(mock.MagicMock(spec=_task.CompiledTask), created_at=created_at)
mock_tasks = [
_task.Task(id=_identifier(_resource_type.TASK, "p1", "d1", "my_task", "my_version"), closure=mock_closure)
]

with mock.patch("flytekit.clients.friendly.SynchronousFlyteClient.list_tasks_paginated") as mock_list_tasks:
mock_list_tasks.return_value = (mock_tasks, None)
result = runner.invoke(pyflyte.main, ["show-versions", "remote-task", "any_task"], catch_exceptions=False)

assert "my_version" in result.output
assert created_at.strftime("%Y-%m-%d %H:%M:%S") in result.output
assert result.exit_code == 0


def test_pyflyte_version_no_workflows():
with mock.patch("flytekit.configuration.plugin.FlyteRemote"):
runner = CliRunner()
result = runner.invoke(pyflyte.main, ["show-versions", "remote-workflow"], catch_exceptions=False)

assert result.exit_code == 0

0 comments on commit a8a6ede

Please sign in to comment.