From 65dd686c21181b2e5b9ee93cd342e6594e272456 Mon Sep 17 00:00:00 2001 From: qdelamea Date: Sat, 2 Nov 2024 19:05:47 +0100 Subject: [PATCH] wip --- src/armonik_cli/commands/common.py | 89 ++++++++++++++++++++- src/armonik_cli/commands/sessions.py | 111 ++++++++++---------------- src/armonik_cli/errors.py | 1 + src/armonik_cli/utils.py | 112 ++++++++++++++++++++++++++- 4 files changed, 239 insertions(+), 74 deletions(-) diff --git a/src/armonik_cli/commands/common.py b/src/armonik_cli/commands/common.py index b498fd7..08077b9 100644 --- a/src/armonik_cli/commands/common.py +++ b/src/armonik_cli/commands/common.py @@ -1,19 +1,65 @@ +import logging import re import rich_click as click from datetime import timedelta +from functools import wraps, partial +from pathlib import Path from typing import cast, Tuple, Union +from armonik.common.channel import create_channel + +from armonik_cli.errors import error_handler +from armonik_cli.utils import get_logger, reconcile_connection_details + endpoint_option = click.option( "-e", "--endpoint", type=str, - required=True, + required=False, help="Endpoint of the cluster to connect to.", + envvar="ARMONIK__ENDPOINT", metavar="ENDPOINT", ) +ca_option = click.option( + "--ca", + "--certificate-authority", + type=click.Path(exists=True, file_okay=True, dir_okay=False, resolve_path=True), + required=False, + help="Path to the certificate authority to read.", + envvar="ARMONIK__CA", + metavar="CA_PATH", +) +cert_option = click.option( + "--cert", + "--client-certificate", + type=click.Path(exists=True, file_okay=True, dir_okay=False, resolve_path=True), + required=False, + help="Path to the client certificate to read.", + envvar="ARMONIK__CERT", + metavar="CERT_PATH", +) +key_option = click.option( + "--key", + "--client-key", + type=click.Path(exists=True, file_okay=True, dir_okay=False, resolve_path=True), + required=False, + help="Path to the client key to read.", + envvar="ARMONIK__KEY", + metavar="KEY_PATH", +) +config_option = click.option( + "-c", + "--config", + "optional_config_file", + type=click.Path(exists=True, file_okay=True, dir_okay=False, resolve_path=True), + required=False, + help="Path to a third-party configuration file.", + envvar="ARMONIK__CONFIG", + metavar="CONFIG_PATH", +) output_option = click.option( "-o", "--output", @@ -121,3 +167,44 @@ def _parse_time_delta(time_str: str) -> timedelta: seconds=int(sec), milliseconds=int(microseconds.ljust(3, "0")), # Ensure 3 digits for milliseconds ) + + +def base_command(func): + if not func: + return partial(base_command) + + @endpoint_option + @ca_option + @cert_option + @key_option + @config_option + @output_option + @debug_option + @error_handler + @wraps(func) + def wrapper( + endpoint: str, + ca: Union[str, None], + cert: Union[str, None], + key: Union[str, None], + optional_config_file: Union[str, None], + output: str, + debug: bool, + *args, + **kwargs, + ): + logger = get_logger(debug) + config = reconcile_connection_details( + { + "endpoint": endpoint, + "certificate_authority": ca, + "client_certificate": cert, + "client_key": key, + }, + optional_config_file, + logger, + ) + channel_ctx = create_channel(config.pop("endpoint"), **config) + return func(channel_ctx, logger, output, *args, **kwargs) + + return wrapper diff --git a/src/armonik_cli/commands/sessions.py b/src/armonik_cli/commands/sessions.py index daabecc..fa4f598 100644 --- a/src/armonik_cli/commands/sessions.py +++ b/src/armonik_cli/commands/sessions.py @@ -1,4 +1,5 @@ -import grpc +import logging + import rich_click as click from datetime import timedelta @@ -6,13 +7,12 @@ from armonik.client.sessions import ArmoniKSessions from armonik.common import SessionStatus, Session, TaskOptions +from armonik.common.channel import create_channel +from grpc import Channel from armonik_cli.console import console -from armonik_cli.errors import error_handler from armonik_cli.commands.common import ( - endpoint_option, - output_option, - debug_option, + base_command, KeyValuePairParam, TimeDeltaParam, ) @@ -29,32 +29,26 @@ def sessions() -> None: @sessions.command() -@endpoint_option -@output_option -@debug_option -@error_handler -def list(endpoint: str, output: str, debug: bool) -> None: +@base_command +def list(channel_ctx: Channel, logger: logging.Logger, output: str) -> None: + # def list(endpoint: str, ca: Union[Path, None], cert: Union[Path, None], key: Union[Path, None], other_config_path: Union[Path, None], output: str, debug: bool) -> None: """List the sessions of an ArmoniK cluster.""" - with grpc.insecure_channel(endpoint) as channel: + with channel_ctx as channel: sessions_client = ArmoniKSessions(channel) total, sessions = sessions_client.list_sessions() - if total > 0: - sessions = [_clean_up_status(s) for s in sessions] - console.formatted_print(sessions, format=output, table_cols=SESSION_TABLE_COLS) + logger.info(f"{total} sessions found.") - console.print(f"\n{total} sessions found.") + sessions = [_clean_up_status(s) for s in sessions] + console.formatted_print(sessions, format=output, table_cols=SESSION_TABLE_COLS) @sessions.command() -@endpoint_option -@output_option -@debug_option @session_argument -@error_handler -def get(endpoint: str, output: str, session_id: str, debug: bool) -> None: +@base_command +def get(channel_ctx: Channel, logger: logging.Logger, output: str, session_id: str) -> None: """Get details of a given session.""" - with grpc.insecure_channel(endpoint) as channel: + with channel_ctx as channel: sessions_client = ArmoniKSessions(channel) session = sessions_client.get_session(session_id=session_id) session = _clean_up_status(session) @@ -62,7 +56,6 @@ def get(endpoint: str, output: str, session_id: str, debug: bool) -> None: @sessions.command() -@endpoint_option @click.option( "--max-retries", type=int, @@ -130,11 +123,9 @@ def get(endpoint: str, output: str, session_id: str, debug: bool) -> None: help="Additional default options.", metavar="KEY=VALUE", ) -@output_option -@debug_option -@error_handler +@base_command def create( - endpoint: str, + channel_ctx: Channel, logger: logging.Logger, output: str, max_retries: int, max_duration: timedelta, priority: int, @@ -146,11 +137,9 @@ def create( application_service: Union[str, None], engine_type: Union[str, None], option: Union[List[Tuple[str, str]], None], - output: str, - debug: bool, ) -> None: """Create a new session.""" - with grpc.insecure_channel(endpoint) as channel: + with channel_ctx as channel: sessions_client = ArmoniKSessions(channel) session_id = sessions_client.create_session( default_task_options=TaskOptions( @@ -173,15 +162,12 @@ def create( @sessions.command() -@endpoint_option @click.confirmation_option("--confirm", prompt="Are you sure you want to cancel this session?") -@output_option -@debug_option @session_argument -@error_handler -def cancel(endpoint: str, output: str, session_id: str, debug: bool) -> None: +@base_command +def cancel(channel_ctx: Channel, logger: logging.Logger, output: str, session_id: str) -> None: """Cancel a session.""" - with grpc.insecure_channel(endpoint) as channel: + with channel_ctx as channel: sessions_client = ArmoniKSessions(channel) session = sessions_client.cancel_session(session_id=session_id) session = _clean_up_status(session) @@ -189,14 +175,11 @@ def cancel(endpoint: str, output: str, session_id: str, debug: bool) -> None: @sessions.command() -@endpoint_option -@output_option -@debug_option @session_argument -@error_handler -def pause(endpoint: str, output: str, session_id: str, debug: bool) -> None: +@base_command +def pause(channel_ctx: Channel, logger: logging.Logger, output: str, session_id: str) -> None: """Pause a session.""" - with grpc.insecure_channel(endpoint) as channel: + with channel_ctx as channel: sessions_client = ArmoniKSessions(channel) session = sessions_client.pause_session(session_id=session_id) session = _clean_up_status(session) @@ -204,14 +187,11 @@ def pause(endpoint: str, output: str, session_id: str, debug: bool) -> None: @sessions.command() -@endpoint_option -@output_option -@debug_option @session_argument -@error_handler -def resume(endpoint: str, output: str, session_id: str, debug: bool) -> None: +@base_command +def resume(channel_ctx: Channel, logger: logging.Logger, output: str, session_id: str) -> None: """Resume a session.""" - with grpc.insecure_channel(endpoint) as channel: + with channel_ctx as channel: sessions_client = ArmoniKSessions(channel) session = sessions_client.resume_session(session_id=session_id) session = _clean_up_status(session) @@ -219,15 +199,12 @@ def resume(endpoint: str, output: str, session_id: str, debug: bool) -> None: @sessions.command() -@endpoint_option @click.confirmation_option("--confirm", prompt="Are you sure you want to close this session?") -@output_option -@debug_option @session_argument -@error_handler -def close(endpoint: str, output: str, session_id: str, debug: bool) -> None: +@base_command +def close(channel_ctx: Channel, logger: logging.Logger, output: str, session_id: str) -> None: """Close a session.""" - with grpc.insecure_channel(endpoint) as channel: + with channel_ctx as channel: sessions_client = ArmoniKSessions(channel) session = sessions_client.close_session(session_id=session_id) session = _clean_up_status(session) @@ -235,15 +212,12 @@ def close(endpoint: str, output: str, session_id: str, debug: bool) -> None: @sessions.command() -@endpoint_option @click.confirmation_option("--confirm", prompt="Are you sure you want to purge this session?") -@output_option -@debug_option @session_argument -@error_handler -def purge(endpoint: str, output: str, session_id: str, debug: bool) -> None: +@base_command +def purge(channel_ctx: Channel, logger: logging.Logger, output: str, session_id: str) -> None: """Purge a session.""" - with grpc.insecure_channel(endpoint) as channel: + with channel_ctx as channel: sessions_client = ArmoniKSessions(channel) session = sessions_client.purge_session(session_id=session_id) session = _clean_up_status(session) @@ -251,15 +225,12 @@ def purge(endpoint: str, output: str, session_id: str, debug: bool) -> None: @sessions.command() -@endpoint_option @click.confirmation_option("--confirm", prompt="Are you sure you want to delete this session?") -@output_option -@debug_option @session_argument -@error_handler -def delete(endpoint: str, output: str, session_id: str, debug: bool) -> None: +@base_command +def delete(channel_ctx: Channel, logger: logging.Logger, output: str, session_id: str) -> None: """Delete a session and associated data from the cluster.""" - with grpc.insecure_channel(endpoint) as channel: + with channel_ctx as channel: sessions_client = ArmoniKSessions(channel) session = sessions_client.delete_session(session_id=session_id) session = _clean_up_status(session) @@ -267,8 +238,6 @@ def delete(endpoint: str, output: str, session_id: str, debug: bool) -> None: @sessions.command() -@endpoint_option -@session_argument @click.option( "--clients-only", is_flag=True, @@ -281,14 +250,12 @@ def delete(endpoint: str, output: str, session_id: str, debug: bool) -> None: default=False, help="Prevent only workers from submitting new tasks in the session.", ) -@output_option -@debug_option -@error_handler +@base_command def stop_submission( - endpoint: str, session_id: str, clients_only: bool, workers_only: bool, output: str, debug: bool + channel_ctx: Channel, logger: logging.Logger, output: str, session_id: str, clients_only: bool, workers_only: bool ) -> None: """Stop clients and/or workers from submitting new tasks in a session.""" - with grpc.insecure_channel(endpoint) as channel: + with channel_ctx as channel: sessions_client = ArmoniKSessions(channel) session = sessions_client.stop_submission_session( session_id=session_id, client=clients_only, worker=workers_only diff --git a/src/armonik_cli/errors.py b/src/armonik_cli/errors.py index 3463ecb..98724ff 100644 --- a/src/armonik_cli/errors.py +++ b/src/armonik_cli/errors.py @@ -42,6 +42,7 @@ def wrapper(*args, **kwargs): else: raise InternalError("An internal fatal error occured.") except Exception: + # TODO: now 'debug' is not provided to functions if "debug" in kwargs and kwargs["debug"]: console.print_exception() else: diff --git a/src/armonik_cli/utils.py b/src/armonik_cli/utils.py index e99a5eb..00248e5 100644 --- a/src/armonik_cli/utils.py +++ b/src/armonik_cli/utils.py @@ -1,10 +1,120 @@ import json +import logging + +import click +import jsonschema from datetime import datetime, timedelta -from typing import Dict, Union, Any +from pathlib import Path +from typing import Dict, Union, Any, Optional from armonik.common import Session, TaskOptions from google._upb._message import ScalarMapContainer +from rich.logging import RichHandler + + +APP_NAME = "armonik" +MSG_FORMAT = "%(message)s" +DATE_FORMAT = "%Y-%m-%dT%H:%M:%S" + + +class ConfigFile: + _schema = schema = { + "type": "object", + "properties": { + "endpoint": {"anyOf": [{"type": "string", "format": "hostname"},{"type": "null"}]}, + "certificate_authority": {"anyOf": [{"type": "string", "format": "uri-reference"},{"type": "null"}]}, + "client_certificate": {"anyOf": [{"type": "string", "format": "uri-reference"},{"type": "null"}]}, + "client_key": {"anyOf": [{"type": "string", "format": "uri-reference"},{"type": "null"}]}, + }, + "required": [], + } + default_path = Path(click.get_app_dir(app_name=APP_NAME)) / "config" + + def __init__(self, loc: Optional[str] = None) -> None: + self.loc = Path(loc) if loc is not None else self.default_path + + def create_if_not_exists(self): + self.loc.parent.mkdir(exist_ok=True) + self.loc.touch() + with self.loc.open("w") as config_file: + config_file.write(json.dumps({})) + + def load(self) -> dict[str, str]: + if not self.loc.exists() and self.loc == self.default_path: + self.create_if_not_exists() + with self.loc.open("r") as config_file: + data = json.loads(config_file.read()) + jsonschema.validate(instance=data, schema=self._schema) + return data + + +def reconcile_connection_details( + cli_params: Dict[str, Optional[Union[str, Path]]], + optional_config_file: Optional[Path], + logger: logging.Logger, +) -> Dict[str, Union[str, Path]]: + """ + Reconciles parameters from command-line, optional config file, and default config file. + Command-line params have highest priority, then optional config, then default config. + + Args: + cli_params (Dict[str, Optional[str]]): Parameters provided via command-line. + optional_config_file (str): Path to the optional config file. + default_config_file (str): Path to the default config file. + + Returns: + Dict[str, str]: Final parameters with reconciled values. + """ + + # Read config files in priority order: optional config file, then default config + if optional_config_file: + optional_config = ConfigFile(optional_config_file).load() + logger.debug(f"Loaded optional config file: {optional_config_file}") + else: + optional_config = None + + default_config = ConfigFile().load() + logger.debug(f"Loaded default config file: {ConfigFile.default_path}") + + final_params = {} + + # Resolve each parameter in priority order + for key in cli_params.keys(): + # Priority 1: command-line parameters + if cli_params.get(key) is not None: + ctx = click.get_current_context() + param_source = ctx.get_parameter_source(param_name) + if param_source == click.core.ParameterSource.COMMANDLINE: + source = "command-line" + elif param_source == click.core.ParameterSource.ENVIRONMENT: + source = "environment variable" + final_params[key] = cli_params[key] + logger.info(f"Parameter '{key}' retrieved from {source}.") + # Priority 2: optional config file, if exists + elif optional_config and optional_config.get(key) is not None: + final_params[key] = optional_config[key] + logger.info(f"Parameter '{key}' retrieved from config file {optional_config.loc}.") + # Priority 3: default config file + elif default_config.get(key) is not None: + final_params[key] = default_config[key] + logger.info(f"Parameter '{key}' taken from default config file.") + else: + if key == "endpoint": + raise click.exceptions.UsageError("No endpoint provided.") + final_params[key] = None + logger.info(f"Parameter '{key}' is missing.") + + return final_params + + +def get_logger(debug: bool) -> None: + level = logging.DEBUG if debug else logging.INFO + logging.basicConfig( + level=level, format=MSG_FORMAT, datefmt=DATE_FORMAT, handlers=[RichHandler()] + ) + logging.captureWarnings(True) + return logging.getLogger("armonik_cli") class CLIJSONEncoder(json.JSONEncoder):