Skip to content

Commit

Permalink
feat: add base_command decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
qdelamea-aneo committed Nov 8, 2024
1 parent 7110c1f commit 92b6720
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 73 deletions.
23 changes: 0 additions & 23 deletions src/armonik_cli/commands/common.py

This file was deleted.

59 changes: 12 additions & 47 deletions src/armonik_cli/commands/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@
from armonik.client.sessions import ArmoniKSessions
from armonik.common import SessionStatus, Session, TaskOptions

from armonik_cli.core import console, error_handler, KeyValuePairParam, TimeDeltaParam
from armonik_cli.commands.common import (
endpoint_option,
output_option,
debug_option,
)
from armonik_cli.core import console, base_command, KeyValuePairParam, TimeDeltaParam


SESSION_TABLE_COLS = [("ID", "SessionId"), ("Status", "Status"), ("CreatedAt", "CreatedAt")]
Expand All @@ -26,10 +21,7 @@ def sessions() -> None:


@sessions.command()
@endpoint_option
@output_option
@debug_option
@error_handler
@base_command
def list(endpoint: str, output: str, debug: bool) -> None:
"""List the sessions of an ArmoniK cluster."""
with grpc.insecure_channel(endpoint) as channel:
Expand All @@ -44,11 +36,8 @@ def list(endpoint: str, output: str, debug: bool) -> None:


@sessions.command()
@endpoint_option
@output_option
@debug_option
@session_argument
@error_handler
@base_command
def get(endpoint: str, output: str, session_id: str, debug: bool) -> None:
"""Get details of a given session."""
with grpc.insecure_channel(endpoint) as channel:
Expand All @@ -59,7 +48,6 @@ def get(endpoint: str, output: str, session_id: str, debug: bool) -> None:


@sessions.command()
@endpoint_option
@click.option(
"--max-retries",
type=int,
Expand Down Expand Up @@ -127,9 +115,7 @@ def get(endpoint: str, output: str, session_id: str, debug: bool) -> None:
help="Additional default options.",
metavar="KEY=VALUE",
)
@output_option
@debug_option
@error_handler
@base_command
def create(
endpoint: str,
max_retries: int,
Expand Down Expand Up @@ -170,12 +156,9 @@ def create(


@sessions.command()
@endpoint_option
@click.confirmation_option("--confirm", prompt="Are you sure you want to cancel this session?")
@output_option
@debug_option
@session_argument
@error_handler
@base_command
def cancel(endpoint: str, output: str, session_id: str, debug: bool) -> None:
"""Cancel a session."""
with grpc.insecure_channel(endpoint) as channel:
Expand All @@ -186,11 +169,8 @@ def cancel(endpoint: str, output: str, session_id: str, debug: bool) -> None:


@sessions.command()
@endpoint_option
@output_option
@debug_option
@session_argument
@error_handler
@base_command
def pause(endpoint: str, output: str, session_id: str, debug: bool) -> None:
"""Pause a session."""
with grpc.insecure_channel(endpoint) as channel:
Expand All @@ -201,11 +181,8 @@ def pause(endpoint: str, output: str, session_id: str, debug: bool) -> None:


@sessions.command()
@endpoint_option
@output_option
@debug_option
@session_argument
@error_handler
@base_command
def resume(endpoint: str, output: str, session_id: str, debug: bool) -> None:
"""Resume a session."""
with grpc.insecure_channel(endpoint) as channel:
Expand All @@ -216,12 +193,9 @@ def resume(endpoint: str, output: str, session_id: str, debug: bool) -> None:


@sessions.command()
@endpoint_option
@click.confirmation_option("--confirm", prompt="Are you sure you want to close this session?")
@output_option
@debug_option
@session_argument
@error_handler
@base_command
def close(endpoint: str, output: str, session_id: str, debug: bool) -> None:
"""Close a session."""
with grpc.insecure_channel(endpoint) as channel:
Expand All @@ -232,12 +206,9 @@ def close(endpoint: str, output: str, session_id: str, debug: bool) -> None:


@sessions.command()
@endpoint_option
@click.confirmation_option("--confirm", prompt="Are you sure you want to purge this session?")
@output_option
@debug_option
@session_argument
@error_handler
@base_command
def purge(endpoint: str, output: str, session_id: str, debug: bool) -> None:
"""Purge a session."""
with grpc.insecure_channel(endpoint) as channel:
Expand All @@ -248,12 +219,9 @@ def purge(endpoint: str, output: str, session_id: str, debug: bool) -> None:


@sessions.command()
@endpoint_option
@click.confirmation_option("--confirm", prompt="Are you sure you want to delete this session?")
@output_option
@debug_option
@session_argument
@error_handler
@base_command
def delete(endpoint: str, output: str, session_id: str, debug: bool) -> None:
"""Delete a session and associated data from the cluster."""
with grpc.insecure_channel(endpoint) as channel:
Expand All @@ -264,8 +232,6 @@ def delete(endpoint: str, output: str, session_id: str, debug: bool) -> None:


@sessions.command()
@endpoint_option
@session_argument
@click.option(
"--clients-only",
is_flag=True,
Expand All @@ -278,9 +244,8 @@ def delete(endpoint: str, output: str, session_id: str, debug: bool) -> None:
default=False,
help="Prevent only workers from submitting new tasks in the session.",
)
@output_option
@debug_option
@error_handler
@session_argument
@base_command
def stop_submission(
endpoint: str, session_id: str, clients_only: bool, workers_only: bool, output: str, debug: bool
) -> None:
Expand Down
4 changes: 2 additions & 2 deletions src/armonik_cli/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from armonik_cli.core.console import console
from armonik_cli.core.decorators import error_handler
from armonik_cli.core.decorators import base_command
from armonik_cli.core.params import KeyValuePairParam, TimeDeltaParam


__all__ = ["error_handler", "KeyValuePairParam", "TimeDeltaParam", "console"]
__all__ = ["base_command", "KeyValuePairParam", "TimeDeltaParam", "console"]
68 changes: 67 additions & 1 deletion src/armonik_cli/core/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,15 @@


def error_handler(func=None):
"""A decorator to manage the correct display of errors.."""
"""Decorator to ensure correct display of errors.
Args:
func: The command function to be decorated. If None, a partial function is returned,
allowing the decorator to be used with parentheses.
Returns:
The wrapped function with added CLI options.
"""
# Allow to call the decorator with parenthesis.
if not func:
return partial(error_handler)
Expand All @@ -34,3 +42,61 @@ def wrapper(*args, **kwargs):
raise InternalError("An internal fatal error occured.")

return wrapper


def base_command(func=None):
"""Decorator to add common CLI options to a Click command function, including
'endpoint', 'output', and 'debug'. These options are automatically passed
as arguments to the decorated function.
The following options are added to the command:
- `--endpoint` (required): Specifies the cluster endpoint.
- `--output`: Sets the output format, with options 'yaml', 'json', or 'table' (default is 'json').
- `--debug`: Enables debug mode, printing additional logs if set.
Warning:
If the decorated function has parameters with the same names as the options added by
this decorator, this can lead to conflicts and unpredictable behavior.
Args:
func: The command function to be decorated. If None, a partial function is returned,
allowing the decorator to be used with parentheses.
Returns:
The wrapped function with added CLI options.
"""

# Allow to call the decorator with parenthesis.
if not func:
return partial(base_command)

# Define the wrapper function with added Click options
@click.option(
"-e",
"--endpoint",
type=str,
required=True,
help="Endpoint of the cluster to connect to.",
metavar="ENDPOINT",
)
@click.option(
"-o",
"--output",
type=click.Choice(["yaml", "json", "table"], case_sensitive=False),
default="json",
show_default=True,
help="Commands output format.",
metavar="FORMAT",
)
@click.option(
"--debug", is_flag=True, default=False, help="Print debug logs and internal errors."
)
@error_handler
@wraps(func)
def wrapper(endpoint: str, output: str, debug: bool, *args, **kwargs):
kwargs["endpoint"] = endpoint
kwargs["output"] = output
kwargs["debug"] = debug
return func(*args, **kwargs)

return wrapper

0 comments on commit 92b6720

Please sign in to comment.