From 71729fa21406d146f0ba0941aeb735567e3e6b89 Mon Sep 17 00:00:00 2001 From: lawhead Date: Mon, 2 Dec 2024 15:51:03 -0800 Subject: [PATCH] Added sim parameter for sampler --- bcipy/simulator/task/task_runner.py | 30 ++++++++++++++++------------- bcipy/simulator/ui/cli.py | 13 ++++++++----- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/bcipy/simulator/task/task_runner.py b/bcipy/simulator/task/task_runner.py index 7aa4f88d..3bc5e9c1 100644 --- a/bcipy/simulator/task/task_runner.py +++ b/bcipy/simulator/task/task_runner.py @@ -2,10 +2,12 @@ import argparse import logging -from glob import glob +import sys from pathlib import Path -from bcipy.simulator.data.sampler import TargetNontargetSampler +# pylint: disable=unused-import +# flake8: noqa +from bcipy.simulator.data.sampler import Sampler, TargetNontargetSampler from bcipy.simulator.task.copy_phrase import SimulatorCopyPhraseTask from bcipy.simulator.task.task_factory import TaskFactory from bcipy.simulator.ui import cli @@ -17,6 +19,11 @@ logger = logging.getLogger(TOP_LEVEL_LOGGER_NAME) +def classify(classname): + """Convert the given class name to a class.""" + return getattr(sys.modules[__name__], classname) + + class TaskRunner(): """Responsible for executing a task a given number of times.""" @@ -67,7 +74,6 @@ def main(): required=False, action='append', help=data_help) - parser.add_argument('-g', '--glob_pattern', help=glob_help, default="*") parser.add_argument( "-m", "--model_path", @@ -85,6 +91,12 @@ def main(): required=False, default=1, help="Number of times to run the simulation") + parser.add_argument("-s", + "--sampler", + type=str, + required=False, + default='TargetNontargetSampler', + help="Sampling strategy") parser.add_argument("-o", "--output", type=Path, @@ -100,18 +112,10 @@ def main(): if args.interactive: task_factory = cli.main(sim_args) else: - source_dirs = sim_args['data_folder'] - if len(source_dirs) == 1: - parent = source_dirs[0] - source_dirs = [ - Path(d) - for d in glob(str(Path(parent, sim_args['glob_pattern']))) - if Path(d).is_dir() - ] task_factory = TaskFactory(params_path=sim_args['parameters'], - source_dirs=source_dirs, + source_dirs=sim_args['data_folder'], signal_model_paths=sim_args['model_path'], - sampling_strategy=TargetNontargetSampler, + sampling_strategy=classify(sim_args['sampler']), task=SimulatorCopyPhraseTask) runner = TaskRunner(save_dir=sim_dir, diff --git a/bcipy/simulator/ui/cli.py b/bcipy/simulator/ui/cli.py index fce1d77c..8784482d 100644 --- a/bcipy/simulator/ui/cli.py +++ b/bcipy/simulator/ui/cli.py @@ -196,7 +196,7 @@ def choose_sampling_strategy() -> Type[Sampler]: options = {klass.__name__: klass for klass in classes} selected = Prompt.ask("Choose a sampling strategy", choices=options.keys(), - default='TargetNonTargetSampler') + default='TargetNontargetSampler') return options[selected] @@ -228,12 +228,13 @@ def get_acq_mode(params_path: str): return params['acq_mode'].get('value', 'EEG') -def command(params: str, models: List[str], source_dirs: List[str]) -> str: +def command(params: str, models: List[str], source_dirs: List[str], sampler: Type[Sampler]) -> str: """Command equivalent to to the result of the interactive selection of simulator inputs.""" model_args = ' '.join([f"-m {path}" for path in models]) dir_args = ' '.join(f"-d {source}" for source in source_dirs) - return f"bcipy-sim -p {params} {model_args} {dir_args}" + sampler_args = f"-s {sampler.__name__}" + return f"bcipy-sim -p {params} {model_args} {sampler_args} {dir_args}" def main(args: Dict[str, Any]) -> TaskFactory: @@ -250,11 +251,13 @@ def main(args: Dict[str, Any]) -> TaskFactory: str(path) for path in select_data(args.get('data_folder', None)) ] - print(command(params, model_paths, source_dirs)) + sampler = choose_sampling_strategy() + + print(command(params, model_paths, source_dirs, sampler)) return TaskFactory(params_path=params, source_dirs=source_dirs, signal_model_paths=model_paths, - sampling_strategy=TargetNontargetSampler, + sampling_strategy=sampler, task=SimulatorCopyPhraseTask)