diff --git a/src/braingeneers/utils/common_utils.py b/src/braingeneers/utils/common_utils.py index c0a4031..6ebb9a1 100644 --- a/src/braingeneers/utils/common_utils.py +++ b/src/braingeneers/utils/common_utils.py @@ -124,17 +124,18 @@ def file_list(filepath: str) -> List[Tuple[str, str, int]]: # Define the wrapper function as a top-level function -def _map2_wrapper(fixed_values: Dict[str, Any], required_params: List[str], func: Callable, args: Tuple) -> Any: - """Internal wrapper function for map2 to handle fixed values and dynamic arguments.""" +def _map2_wrapper(fixed_values: Dict[str, Any], required_params: List[str], func: Callable, args: Tuple, func_kwargs: Dict[str, Any]) -> Any: + """Internal wrapper function for map2 to handle fixed values and dynamic arguments, including kwargs.""" # Merge fixed_values with provided arguments, aligning provided args with required_params call_args = {**fixed_values, **dict(zip(required_params, args))} - return func(**call_args) + return func(**call_args, **func_kwargs) def map2(func: Callable, - args: Iterable[Union[Tuple, object]] = None, + args: Iterable[Tuple[Any, ...]] = None, + kwargs: Iterable[Dict[str, Any]] = None, fixed_values: dict = None, - parallelism: (bool, int) = True, + parallelism: Union[bool, int] = True, use_multithreading: bool = False) -> List[object]: """ A universal multiprocessing version of the map function to simplify parallelizing code. @@ -150,22 +151,35 @@ def map2(func: Callable, def f(x, y): print(x, y) - common_py_utils.map2( + common_utils.map2( func=f, args=[(1, 'yellow'), (2, 'yarn'), (3, 'yack')], # (x, y) arguments parallelism=3, # use a 3 process multiprocessing pool ) - common_py_utils.map2( + common_utils.map2( func=f, - args=[1, 2, 3], # x arguments + args=[1, 2, 3], # x arguments has multiple values to run fixed_values=dict(y='yellow'), # y always is 'yellow' parallelism=False, # Runs without parallelism which makes debugging exceptions easier ) + Usage example incorporating kwargs: + def myfunc(a, b, **kwargs): + print(a, b, kwargs.get('c')) + + common_utils.map2( + func=myfunc, + args=[(1, 2), (3, 4)], + kwargs=[{'c': 50}, {'c': 100}], + ) + :param func: a callable function :param args: a list of arguments (if only 1 argument is left after fixed_values) or a list of tuples (if multiple arguments are left after fixed_values) + :param kwargs: an iterable of dictionaries where each dictionary represents the keyword arguments to pass + to the function for each call. This parameter allows passing dynamic keyword arguments to the function. + the length of args and kwargs must be equal. :param fixed_values: a dictionary with parameters that will stay the same for each call to func :param parallelism: number of processes to use or boolean, default is # of CPU cores. When parallelism==False or 1, this maps to itertools.starmap and does not use multiprocessing. @@ -176,6 +190,9 @@ def f(x, y): due to the GIL, threads are lighter weight than processes for some non cpu-intensive tasks. :return: a list of the return values of func """ + if args is not None and kwargs is not None: + assert len(args) == len(kwargs), \ + f"args and kwargs must have the same length, found lengths: len(args)={len(args)} and len(kwargs)={len(kwargs)}" assert isinstance(fixed_values, (dict, type(None))) assert parallelism is False or isinstance(parallelism, (bool, int)), "parallelism must be a boolean or an integer" parallelism = multiprocessing.cpu_count() if parallelism is True else 1 if parallelism is False else parallelism @@ -187,15 +204,22 @@ def f(x, y): p.default == inspect.Parameter.empty and p.name not in fixed_values] args_list = list(args or []) + kwargs_list = list(kwargs or []) args_tuples = args_list if all(isinstance(a, tuple) for a in args_list) else [(a,) for a in args_list] + # Adjusted to handle cases where args might not be provided + call_parameters = list(zip(args_tuples, kwargs_list)) if args_tuples else [((), kw) for kw in kwargs_list] + if parallelism == 1: - result_iterator = map(lambda args: _map2_wrapper(fixed_values, required_params, func, args), args_tuples) + result_iterator = map(lambda params: _map2_wrapper(fixed_values, required_params, func, params[0], params[1]), + call_parameters) else: ProcessOrThreadPool = multiprocessing.pool.ThreadPool if use_multithreading else multiprocessing.Pool with ProcessOrThreadPool(parallelism) as pool: - result_iterator = pool.starmap(_map2_wrapper, - [(fixed_values, required_params, func, args) for args in args_tuples]) + result_iterator = pool.starmap( + _map2_wrapper, + [(fixed_values, required_params, func, args, kw) for args, kw in call_parameters] + ) return list(result_iterator) diff --git a/src/braingeneers/utils/common_utils_test.py b/src/braingeneers/utils/common_utils_test.py index be1cc90..8544da7 100644 --- a/src/braingeneers/utils/common_utils_test.py +++ b/src/braingeneers/utils/common_utils_test.py @@ -2,7 +2,7 @@ import unittest from unittest.mock import patch, MagicMock import common_utils -from common_utils import checkout, force_release_checkout +from common_utils import checkout, force_release_checkout, map2 from braingeneers.iot import messaging import os import tempfile @@ -98,6 +98,79 @@ def test_checkout_context_manager_write_binary(self): locked_obj.checkin(test_data) self.mock_file.write.assert_called_once_with(test_data) + def test_with_pass_through_kwargs_handling(self): + """Test map2 with a function accepting dynamic kwargs, specifically to check the handling of 'experiment_name' + passed through **kwargs, using the original signature for f_with_kwargs.""" + + def f_with_kwargs(cache_path: str, max_size_gb: int = 10, **kwargs): + # Simulate loading data where 'experiment_name' and other parameters are expected to come through **kwargs + self.assertTrue(isinstance(kwargs, dict), 'kwargs should be a dict') + self.assertFalse('kwargs' in kwargs) + return 'some data' + + experiments = [{'experiment': 'exp1'}, {'experiment': 'exp2'}] # List of experiment names to be passed as individual kwargs + fixed_values = { + "cache_path": '/tmp/ephys_cache', + "max_size_gb": 50, + "metadata": {"some": "metadata"}, + "channels": ["channel1"], + "length": -1, + } + + # Execute the test under the assumption that map2 is supposed to handle 'experiment_name' in **kwargs correctly + map2(f_with_kwargs, kwargs=experiments, fixed_values=fixed_values, parallelism=False) + self.assertTrue(True) # If the test reaches this point, it has passed + + +class TestMap2Function(unittest.TestCase): + def test_with_kwargs_function_parallelism_false(self): + # Define a test function that takes a positional argument and arbitrary kwargs + def test_func(a, **kwargs): + return a + kwargs.get('increment', 0) + + # Define the arguments and kwargs to pass to map2 + args = [(1,), (2,), (3,)] # positional arguments + kwargs = [{'increment': 10}, {'increment': 20}, {'increment': 30}] # kwargs for each call + + # Call map2 with the test function, args, kwargs, and parallelism=False + result = map2( + func=test_func, + args=args, + kwargs=kwargs, + parallelism=False + ) + + # Expected results after applying the function with the given args and kwargs + expected_results = [11, 22, 33] + + # Assert that the actual result matches the expected result + self.assertEqual(result, expected_results) + + def test_with_fixed_values_and_variable_kwargs_parallelism_false(self): + # Define a test function that takes fixed positional argument and arbitrary kwargs + def test_func(a, **kwargs): + return a + kwargs.get('increment', 0) + + # Since 'a' is now a fixed value, we no longer need to provide it in args + args = [] # No positional arguments are passed here + + # Define the kwargs to pass to map2, each dict represents kwargs for one call + kwargs = [{'increment': 10}, {'increment': 20}, {'increment': 30}] + + # Call map2 with the test function, no args, variable kwargs, fixed_values containing 'a', and parallelism=False + result = map2( + func=test_func, + kwargs=kwargs, + fixed_values={'a': 1}, # 'a' is fixed for all calls + parallelism=False + ) + + # Expected results after applying the function with the fixed 'a' and given kwargs + expected_results = [11, 21, 31] + + # Assert that the actual result matches the expected result + self.assertEqual(result, expected_results) + if __name__ == '__main__': unittest.main()