From f77392ff65946344f414630a1d8e523811ba8b5b Mon Sep 17 00:00:00 2001 From: Miauwkeru Date: Fri, 5 Jul 2024 10:54:39 +0000 Subject: [PATCH] Add optional boolean value to all store_true instances --- acquire/acquire.py | 10 ++++------ acquire/utils.py | 24 ++++++++++++++---------- tests/test_acquire_command.py | 25 ++++++++++++++++++++----- 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/acquire/acquire.py b/acquire/acquire.py index d52f4e3d..2d7a161f 100644 --- a/acquire/acquire.py +++ b/acquire/acquire.py @@ -215,7 +215,7 @@ def wrapper(module_cls: type[Module]) -> type[Module]: desc = module_cls.DESC or name kwargs["help"] = f"acquire {desc}" - kwargs["action"] = "store_true" + kwargs["action"] = argparse.BooleanOptionalAction kwargs["dest"] = name.lower() module_cls.__modname__ = name @@ -661,15 +661,13 @@ def recyclebin_filter(path: fsutil.TargetPath) -> bool: @register_module("--recyclebin") @module_arg( "--large-files", - action="store_true", + action=argparse.BooleanOptionalAction, help="Collect files larger than 10MB in the Recycle Bin", - default=False, ) @module_arg( "--no-data-files", - action="store_true", + action=argparse.BooleanOptionalAction, help="Skip collection of data files in the Recycle Bin", - default=False, ) class RecycleBin(Module): DESC = "recycle bin metadata and data files" @@ -1318,7 +1316,7 @@ class Home(Module): @register_module("--ssh") -@module_arg("--private-keys", action="store_true", help="Add any private keys", default=False) +@module_arg("--private-keys", action=argparse.BooleanOptionalAction, help="Add any private keys") class SSH(Module): SPEC = [ ("glob", ".ssh/*", from_user_home), diff --git a/acquire/utils.py b/acquire/utils.py index 034eb807..67d1e191 100644 --- a/acquire/utils.py +++ b/acquire/utils.py @@ -81,12 +81,12 @@ def create_argument_parser(profiles: dict, volatile: dict, modules: dict) -> arg ) parser.add_argument( "--compress", - action="store_true", + action=argparse.BooleanOptionalAction, help="compress output (if supported by the output type)", ) parser.add_argument( "--encrypt", - action="store_true", + action=argparse.BooleanOptionalAction, help="encrypt output (if supported by the output type)", ) parser.add_argument( @@ -99,7 +99,7 @@ def create_argument_parser(profiles: dict, volatile: dict, modules: dict) -> arg ) parser.add_argument("--public-key", type=Path, help=argparse.SUPPRESS) parser.add_argument("-l", "--log", type=Path, help="log directory location") - parser.add_argument("--no-log", action="store_true", help=argparse.SUPPRESS) + parser.add_argument("--no-log", action=argparse.BooleanOptionalAction, help=argparse.SUPPRESS) parser.add_argument( "-L", "--loader", @@ -114,24 +114,28 @@ def create_argument_parser(profiles: dict, volatile: dict, modules: dict) -> arg parser.add_argument("-d", "--directory", action="append", help="acquire directory recursively") parser.add_argument("-g", "--glob", action="append", help="acquire files matching glob pattern") - parser.add_argument("--disable-report", action="store_true", help="disable acquisition report file") + parser.add_argument( + "--disable-report", action=argparse.BooleanOptionalAction, help="disable acquisition report file" + ) parser.add_argument("--child", help="only collect specific child") parser.add_argument( "--children", - action="store_true", + action=argparse.BooleanOptionalAction, help="collect all children in addition to main target", ) - parser.add_argument("--skip-parent", action="store_true", help="skip parent collection (when using --children)") + parser.add_argument( + "--skip-parent", action=argparse.BooleanOptionalAction, help="skip parent collection (when using --children)" + ) parser.add_argument( "--force-fallback", - action="store_true", + action=argparse.BooleanOptionalAction, help="force filesystem access directly through OS level. Only supported with target 'local'", ) parser.add_argument( "--fallback", - action="store_true", + action=argparse.BooleanOptionalAction, help=( "fallback to OS level filesystem access if filesystem type is not supported. " "Only supported with target 'local'" @@ -141,7 +145,7 @@ def create_argument_parser(profiles: dict, volatile: dict, modules: dict) -> arg parser.add_argument( "-u", "--auto-upload", - action="store_true", + action=argparse.BooleanOptionalAction, help="upload result files after collection", ) parser.add_argument( @@ -149,7 +153,7 @@ def create_argument_parser(profiles: dict, volatile: dict, modules: dict) -> arg nargs="+", help="upload specified files (all other acquire actions are ignored)", ) - parser.add_argument("--no-proxy", action="store_true", help="don't autodetect proxies") + parser.add_argument("--no-proxy", action=argparse.BooleanOptionalAction, help="don't autodetect proxies") for module_cls in modules.values(): for args, kwargs in module_cls.__cli_args__: diff --git a/tests/test_acquire_command.py b/tests/test_acquire_command.py index 73a20bcf..b1d2cb6e 100644 --- a/tests/test_acquire_command.py +++ b/tests/test_acquire_command.py @@ -1,11 +1,9 @@ from argparse import Namespace -from typing import List from unittest.mock import patch import pytest from acquire.acquire import ( - CONFIG, MODULES, PROFILES, VOLATILE, @@ -15,10 +13,11 @@ @pytest.fixture -def acquire_parser_args(config: List, argument_list: List) -> Namespace: - CONFIG["arguments"] = config +def acquire_parser_args(config: list[str], argument_list: list[str]) -> Namespace: + config_dict = {} + config_dict["arguments"] = config with patch("argparse._sys.argv", [""] + argument_list): - return parse_acquire_args(create_argument_parser(PROFILES, VOLATILE, MODULES), config=CONFIG)[0] + return parse_acquire_args(create_argument_parser(PROFILES, VOLATILE, MODULES), config=config_dict)[0] @pytest.mark.parametrize("config, argument_list", [([], [])]) @@ -39,3 +38,19 @@ def test_config_default_argument_override(acquire_parser_args): @pytest.mark.parametrize("config, argument_list", [([], ["target1", "target2"])]) def test_local_target_fallbactargets(acquire_parser_args): assert acquire_parser_args.targets == ["target1", "target2"] + + +@pytest.mark.parametrize( + "config, argument_list, arg_to_test, expected_value", + [ + (["--etc"], ["--no-etc"], "etc", False), + (["--no-etc"], ["--etc"], "etc", True), + (["--encrypt"], ["--no-encrypt"], "encrypt", False), + (["--no-encrypt"], ["--encrypt"], "encrypt", True), + (["--encrypt", "--ssh"], ["--no-ssh"], "ssh", False), + (["--private-keys"], ["--no-private-keys"], "private_keys", False), + (["--no-private-keys"], ["--private-keys"], "private_keys", True), + ], +) +def test_overwrites_optionals(acquire_parser_args: Namespace, arg_to_test: str, expected_value: bool): + assert getattr(acquire_parser_args, arg_to_test) is expected_value