From 68d134e5a5e661ef3f7d4743cbd2ca76d93b28a7 Mon Sep 17 00:00:00 2001 From: davidparks21 Date: Tue, 12 Mar 2024 11:56:57 -0700 Subject: [PATCH] updated map2 to support vararg functions --- src/braingeneers/utils/common_utils.py | 49 ++++++++++++++++----- src/braingeneers/utils/common_utils_test.py | 23 ++++++++++ 2 files changed, 60 insertions(+), 12 deletions(-) diff --git a/src/braingeneers/utils/common_utils.py b/src/braingeneers/utils/common_utils.py index c0a4031..9d2af20 100644 --- a/src/braingeneers/utils/common_utils.py +++ b/src/braingeneers/utils/common_utils.py @@ -124,17 +124,18 @@ def file_list(filepath: str) -> List[Tuple[str, str, int]]: # Define the wrapper function as a top-level function -def _map2_wrapper(fixed_values: Dict[str, Any], required_params: List[str], func: Callable, args: Tuple) -> Any: - """Internal wrapper function for map2 to handle fixed values and dynamic arguments.""" +def _map2_wrapper(fixed_values: Dict[str, Any], required_params: List[str], func: Callable, args: Tuple, func_kwargs: Dict[str, Any]) -> Any: + """Internal wrapper function for map2 to handle fixed values and dynamic arguments, including kwargs.""" # Merge fixed_values with provided arguments, aligning provided args with required_params call_args = {**fixed_values, **dict(zip(required_params, args))} - return func(**call_args) + return func(**call_args, **func_kwargs) def map2(func: Callable, - args: Iterable[Union[Tuple, object]] = None, + args: Iterable[Tuple[Any, ...]] = None, + kwargs: Iterable[Dict[str, Any]] = None, fixed_values: dict = None, - parallelism: (bool, int) = True, + parallelism: Union[bool, int] = True, use_multithreading: bool = False) -> List[object]: """ A universal multiprocessing version of the map function to simplify parallelizing code. @@ -150,22 +151,35 @@ def map2(func: Callable, def f(x, y): print(x, y) - common_py_utils.map2( + common_utils.map2( func=f, args=[(1, 'yellow'), (2, 'yarn'), (3, 'yack')], # (x, y) arguments parallelism=3, # use a 3 process multiprocessing pool ) - common_py_utils.map2( + common_utils.map2( func=f, - args=[1, 2, 3], # x arguments + args=[1, 2, 3], # x arguments has multiple values to run fixed_values=dict(y='yellow'), # y always is 'yellow' parallelism=False, # Runs without parallelism which makes debugging exceptions easier ) + Usage example incorporating kwargs: + def myfunc(a, b, **kwargs): + print(a, b, kwargs.get('c')) + + common_utils.map2( + func=myfunc, + args=[(1, 2), (3, 4)], + kwargs=[{'c': 50}, {'c': 100}], + ) + :param func: a callable function :param args: a list of arguments (if only 1 argument is left after fixed_values) or a list of tuples (if multiple arguments are left after fixed_values) + :param kwargs: an iterable of dictionaries where each dictionary represents the keyword arguments to pass + to the function for each call. This parameter allows passing dynamic keyword arguments to the function. + the length of args and kwargs must be equal. :param fixed_values: a dictionary with parameters that will stay the same for each call to func :param parallelism: number of processes to use or boolean, default is # of CPU cores. When parallelism==False or 1, this maps to itertools.starmap and does not use multiprocessing. @@ -176,6 +190,9 @@ def f(x, y): due to the GIL, threads are lighter weight than processes for some non cpu-intensive tasks. :return: a list of the return values of func """ + if args is not None and kwargs is not None: + assert len(args) == len(kwargs), \ + "args and kwargs must have the same length, found lengths: len(args)={len(args)} and len(kwargs)={len(kwargs)}" assert isinstance(fixed_values, (dict, type(None))) assert parallelism is False or isinstance(parallelism, (bool, int)), "parallelism must be a boolean or an integer" parallelism = multiprocessing.cpu_count() if parallelism is True else 1 if parallelism is False else parallelism @@ -186,20 +203,28 @@ def f(x, y): required_params = [p.name for p in func_signature.parameters.values() if p.default == inspect.Parameter.empty and p.name not in fixed_values] + # Ensure args and kwargs are iterable or set them as empty lists if None args_list = list(args or []) + kwargs_list = list(kwargs or []) + # Convert args to tuples if they're not already args_tuples = args_list if all(isinstance(a, tuple) for a in args_list) else [(a,) for a in args_list] + # Prepare tuples of (fixed_values, required_params, func, args, kwargs) for each call + call_parameters = zip(args_tuples, kwargs_list) if kwargs_list else zip(args_tuples, [{}] * len(args_tuples)) + if parallelism == 1: - result_iterator = map(lambda args: _map2_wrapper(fixed_values, required_params, func, args), args_tuples) + result_iterator = (map(lambda params: _map2_wrapper(fixed_values, required_params, func, params[0], params[1]), + call_parameters)) else: ProcessOrThreadPool = multiprocessing.pool.ThreadPool if use_multithreading else multiprocessing.Pool with ProcessOrThreadPool(parallelism) as pool: - result_iterator = pool.starmap(_map2_wrapper, - [(fixed_values, required_params, func, args) for args in args_tuples]) + result_iterator = pool.starmap( + _map2_wrapper, + [(fixed_values, required_params, func, args, kw) for args, kw in call_parameters] + ) return list(result_iterator) - class checkout: """ A context manager for atomically checking out a file from S3 for reading or writing. diff --git a/src/braingeneers/utils/common_utils_test.py b/src/braingeneers/utils/common_utils_test.py index be1cc90..471bf1d 100644 --- a/src/braingeneers/utils/common_utils_test.py +++ b/src/braingeneers/utils/common_utils_test.py @@ -98,6 +98,29 @@ def test_checkout_context_manager_write_binary(self): locked_obj.checkin(test_data) self.mock_file.write.assert_called_once_with(test_data) + def test_with_pass_through_kwargs_handling(self): + """Test map2 with a function accepting dynamic kwargs, specifically to check the handling of 'experiment_name' + passed through **kwargs, using the original signature for f_with_kwargs.""" + + def f_with_kwargs(cache_path: str, max_size_gb: int = 10, **kwargs): + # Simulate loading data where 'experiment_name' and other parameters are expected to come through **kwargs + self.assertTrue(isinstance(kwargs, dict), 'kwargs should be a dict') + self.assertFalse('kwargs' in kwargs) + return 'some data' + + experiments = [{'experiment': 'exp1'}, {'experiment': 'exp2'}] # List of experiment names to be passed as individual kwargs + fixed_values = { + "cache_path": '/tmp/ephys_cache', + "max_size_gb": 50, + "metadata": {"some": "metadata"}, + "channels": ["channel1"], + "length": -1, + } + + # Execute the test under the assumption that map2 is supposed to handle 'experiment_name' in **kwargs correctly + map2(f_with_kwargs, kwargs=experiments, fixed_values=fixed_values, parallelism=False) + self.assertTrue(True) # If the test reaches this point, it has passed + if __name__ == '__main__': unittest.main()