Skip to content

Commit

Permalink
Add optional boolean value to all store_true instances
Browse files Browse the repository at this point in the history
  • Loading branch information
Miauwkeru committed Jul 8, 2024
1 parent 8741da3 commit f77392f
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 21 deletions.
10 changes: 4 additions & 6 deletions acquire/acquire.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def wrapper(module_cls: type[Module]) -> type[Module]:

desc = module_cls.DESC or name
kwargs["help"] = f"acquire {desc}"
kwargs["action"] = "store_true"
kwargs["action"] = argparse.BooleanOptionalAction
kwargs["dest"] = name.lower()
module_cls.__modname__ = name

Expand Down Expand Up @@ -661,15 +661,13 @@ def recyclebin_filter(path: fsutil.TargetPath) -> bool:
@register_module("--recyclebin")
@module_arg(
"--large-files",
action="store_true",
action=argparse.BooleanOptionalAction,
help="Collect files larger than 10MB in the Recycle Bin",
default=False,
)
@module_arg(
"--no-data-files",
action="store_true",
action=argparse.BooleanOptionalAction,
help="Skip collection of data files in the Recycle Bin",
default=False,
)
class RecycleBin(Module):
DESC = "recycle bin metadata and data files"
Expand Down Expand Up @@ -1318,7 +1316,7 @@ class Home(Module):


@register_module("--ssh")
@module_arg("--private-keys", action="store_true", help="Add any private keys", default=False)
@module_arg("--private-keys", action=argparse.BooleanOptionalAction, help="Add any private keys")
class SSH(Module):
SPEC = [
("glob", ".ssh/*", from_user_home),
Expand Down
24 changes: 14 additions & 10 deletions acquire/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ def create_argument_parser(profiles: dict, volatile: dict, modules: dict) -> arg
)
parser.add_argument(
"--compress",
action="store_true",
action=argparse.BooleanOptionalAction,
help="compress output (if supported by the output type)",
)
parser.add_argument(
"--encrypt",
action="store_true",
action=argparse.BooleanOptionalAction,
help="encrypt output (if supported by the output type)",
)
parser.add_argument(
Expand All @@ -99,7 +99,7 @@ def create_argument_parser(profiles: dict, volatile: dict, modules: dict) -> arg
)
parser.add_argument("--public-key", type=Path, help=argparse.SUPPRESS)
parser.add_argument("-l", "--log", type=Path, help="log directory location")
parser.add_argument("--no-log", action="store_true", help=argparse.SUPPRESS)
parser.add_argument("--no-log", action=argparse.BooleanOptionalAction, help=argparse.SUPPRESS)
parser.add_argument(
"-L",
"--loader",
Expand All @@ -114,24 +114,28 @@ def create_argument_parser(profiles: dict, volatile: dict, modules: dict) -> arg
parser.add_argument("-d", "--directory", action="append", help="acquire directory recursively")
parser.add_argument("-g", "--glob", action="append", help="acquire files matching glob pattern")

parser.add_argument("--disable-report", action="store_true", help="disable acquisition report file")
parser.add_argument(
"--disable-report", action=argparse.BooleanOptionalAction, help="disable acquisition report file"
)

parser.add_argument("--child", help="only collect specific child")
parser.add_argument(
"--children",
action="store_true",
action=argparse.BooleanOptionalAction,
help="collect all children in addition to main target",
)
parser.add_argument("--skip-parent", action="store_true", help="skip parent collection (when using --children)")
parser.add_argument(
"--skip-parent", action=argparse.BooleanOptionalAction, help="skip parent collection (when using --children)"
)

parser.add_argument(
"--force-fallback",
action="store_true",
action=argparse.BooleanOptionalAction,
help="force filesystem access directly through OS level. Only supported with target 'local'",
)
parser.add_argument(
"--fallback",
action="store_true",
action=argparse.BooleanOptionalAction,
help=(
"fallback to OS level filesystem access if filesystem type is not supported. "
"Only supported with target 'local'"
Expand All @@ -141,15 +145,15 @@ def create_argument_parser(profiles: dict, volatile: dict, modules: dict) -> arg
parser.add_argument(
"-u",
"--auto-upload",
action="store_true",
action=argparse.BooleanOptionalAction,
help="upload result files after collection",
)
parser.add_argument(
"--upload",
nargs="+",
help="upload specified files (all other acquire actions are ignored)",
)
parser.add_argument("--no-proxy", action="store_true", help="don't autodetect proxies")
parser.add_argument("--no-proxy", action=argparse.BooleanOptionalAction, help="don't autodetect proxies")

for module_cls in modules.values():
for args, kwargs in module_cls.__cli_args__:
Expand Down
25 changes: 20 additions & 5 deletions tests/test_acquire_command.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from argparse import Namespace
from typing import List
from unittest.mock import patch

import pytest

from acquire.acquire import (
CONFIG,
MODULES,
PROFILES,
VOLATILE,
Expand All @@ -15,10 +13,11 @@


@pytest.fixture
def acquire_parser_args(config: List, argument_list: List) -> Namespace:
CONFIG["arguments"] = config
def acquire_parser_args(config: list[str], argument_list: list[str]) -> Namespace:
config_dict = {}
config_dict["arguments"] = config
with patch("argparse._sys.argv", [""] + argument_list):
return parse_acquire_args(create_argument_parser(PROFILES, VOLATILE, MODULES), config=CONFIG)[0]
return parse_acquire_args(create_argument_parser(PROFILES, VOLATILE, MODULES), config=config_dict)[0]


@pytest.mark.parametrize("config, argument_list", [([], [])])
Expand All @@ -39,3 +38,19 @@ def test_config_default_argument_override(acquire_parser_args):
@pytest.mark.parametrize("config, argument_list", [([], ["target1", "target2"])])
def test_local_target_fallbactargets(acquire_parser_args):
assert acquire_parser_args.targets == ["target1", "target2"]


@pytest.mark.parametrize(
"config, argument_list, arg_to_test, expected_value",
[
(["--etc"], ["--no-etc"], "etc", False),
(["--no-etc"], ["--etc"], "etc", True),
(["--encrypt"], ["--no-encrypt"], "encrypt", False),
(["--no-encrypt"], ["--encrypt"], "encrypt", True),
(["--encrypt", "--ssh"], ["--no-ssh"], "ssh", False),
(["--private-keys"], ["--no-private-keys"], "private_keys", False),
(["--no-private-keys"], ["--private-keys"], "private_keys", True),
],
)
def test_overwrites_optionals(acquire_parser_args: Namespace, arg_to_test: str, expected_value: bool):
assert getattr(acquire_parser_args, arg_to_test) is expected_value

0 comments on commit f77392f

Please sign in to comment.