Skip to content

Commit

Permalink
Map2 support varargs functions (#76)
Browse files Browse the repository at this point in the history
* updated map2 to support vararg functions

* updated map2 to support vararg functions

minor

updated map2 to support vararg functions

minor

* Fixed new bug found and added unit tests
  • Loading branch information
davidparks21 authored May 22, 2024
1 parent 38d3e74 commit 54d3d13
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 12 deletions.
46 changes: 35 additions & 11 deletions src/braingeneers/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,17 +124,18 @@ def file_list(filepath: str) -> List[Tuple[str, str, int]]:


# Define the wrapper function as a top-level function
def _map2_wrapper(fixed_values: Dict[str, Any], required_params: List[str], func: Callable, args: Tuple) -> Any:
"""Internal wrapper function for map2 to handle fixed values and dynamic arguments."""
def _map2_wrapper(fixed_values: Dict[str, Any], required_params: List[str], func: Callable, args: Tuple, func_kwargs: Dict[str, Any]) -> Any:
"""Internal wrapper function for map2 to handle fixed values and dynamic arguments, including kwargs."""
# Merge fixed_values with provided arguments, aligning provided args with required_params
call_args = {**fixed_values, **dict(zip(required_params, args))}
return func(**call_args)
return func(**call_args, **func_kwargs)


def map2(func: Callable,
args: Iterable[Union[Tuple, object]] = None,
args: Iterable[Tuple[Any, ...]] = None,
kwargs: Iterable[Dict[str, Any]] = None,
fixed_values: dict = None,
parallelism: (bool, int) = True,
parallelism: Union[bool, int] = True,
use_multithreading: bool = False) -> List[object]:
"""
A universal multiprocessing version of the map function to simplify parallelizing code.
Expand All @@ -150,22 +151,35 @@ def map2(func: Callable,
def f(x, y):
print(x, y)
common_py_utils.map2(
common_utils.map2(
func=f,
args=[(1, 'yellow'), (2, 'yarn'), (3, 'yack')], # (x, y) arguments
parallelism=3, # use a 3 process multiprocessing pool
)
common_py_utils.map2(
common_utils.map2(
func=f,
args=[1, 2, 3], # x arguments
args=[1, 2, 3], # x arguments has multiple values to run
fixed_values=dict(y='yellow'), # y always is 'yellow'
parallelism=False, # Runs without parallelism which makes debugging exceptions easier
)
Usage example incorporating kwargs:
def myfunc(a, b, **kwargs):
print(a, b, kwargs.get('c'))
common_utils.map2(
func=myfunc,
args=[(1, 2), (3, 4)],
kwargs=[{'c': 50}, {'c': 100}],
)
:param func: a callable function
:param args: a list of arguments (if only 1 argument is left after fixed_values) or a list of tuples
(if multiple arguments are left after fixed_values)
:param kwargs: an iterable of dictionaries where each dictionary represents the keyword arguments to pass
to the function for each call. This parameter allows passing dynamic keyword arguments to the function.
the length of args and kwargs must be equal.
:param fixed_values: a dictionary with parameters that will stay the same for each call to func
:param parallelism: number of processes to use or boolean, default is # of CPU cores.
When parallelism==False or 1, this maps to itertools.starmap and does not use multiprocessing.
Expand All @@ -176,6 +190,9 @@ def f(x, y):
due to the GIL, threads are lighter weight than processes for some non cpu-intensive tasks.
:return: a list of the return values of func
"""
if args is not None and kwargs is not None:
assert len(args) == len(kwargs), \
f"args and kwargs must have the same length, found lengths: len(args)={len(args)} and len(kwargs)={len(kwargs)}"
assert isinstance(fixed_values, (dict, type(None)))
assert parallelism is False or isinstance(parallelism, (bool, int)), "parallelism must be a boolean or an integer"
parallelism = multiprocessing.cpu_count() if parallelism is True else 1 if parallelism is False else parallelism
Expand All @@ -187,15 +204,22 @@ def f(x, y):
p.default == inspect.Parameter.empty and p.name not in fixed_values]

args_list = list(args or [])
kwargs_list = list(kwargs or [])
args_tuples = args_list if all(isinstance(a, tuple) for a in args_list) else [(a,) for a in args_list]

# Adjusted to handle cases where args might not be provided
call_parameters = list(zip(args_tuples, kwargs_list)) if args_tuples else [((), kw) for kw in kwargs_list]

if parallelism == 1:
result_iterator = map(lambda args: _map2_wrapper(fixed_values, required_params, func, args), args_tuples)
result_iterator = map(lambda params: _map2_wrapper(fixed_values, required_params, func, params[0], params[1]),
call_parameters)
else:
ProcessOrThreadPool = multiprocessing.pool.ThreadPool if use_multithreading else multiprocessing.Pool
with ProcessOrThreadPool(parallelism) as pool:
result_iterator = pool.starmap(_map2_wrapper,
[(fixed_values, required_params, func, args) for args in args_tuples])
result_iterator = pool.starmap(
_map2_wrapper,
[(fixed_values, required_params, func, args, kw) for args, kw in call_parameters]
)

return list(result_iterator)

Expand Down
75 changes: 74 additions & 1 deletion src/braingeneers/utils/common_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import unittest
from unittest.mock import patch, MagicMock
import common_utils
from common_utils import checkout, force_release_checkout
from common_utils import checkout, force_release_checkout, map2
from braingeneers.iot import messaging
import os
import tempfile
Expand Down Expand Up @@ -98,6 +98,79 @@ def test_checkout_context_manager_write_binary(self):
locked_obj.checkin(test_data)
self.mock_file.write.assert_called_once_with(test_data)

def test_with_pass_through_kwargs_handling(self):
"""Test map2 with a function accepting dynamic kwargs, specifically to check the handling of 'experiment_name'
passed through **kwargs, using the original signature for f_with_kwargs."""

def f_with_kwargs(cache_path: str, max_size_gb: int = 10, **kwargs):
# Simulate loading data where 'experiment_name' and other parameters are expected to come through **kwargs
self.assertTrue(isinstance(kwargs, dict), 'kwargs should be a dict')
self.assertFalse('kwargs' in kwargs)
return 'some data'

experiments = [{'experiment': 'exp1'}, {'experiment': 'exp2'}] # List of experiment names to be passed as individual kwargs
fixed_values = {
"cache_path": '/tmp/ephys_cache',
"max_size_gb": 50,
"metadata": {"some": "metadata"},
"channels": ["channel1"],
"length": -1,
}

# Execute the test under the assumption that map2 is supposed to handle 'experiment_name' in **kwargs correctly
map2(f_with_kwargs, kwargs=experiments, fixed_values=fixed_values, parallelism=False)
self.assertTrue(True) # If the test reaches this point, it has passed


class TestMap2Function(unittest.TestCase):
def test_with_kwargs_function_parallelism_false(self):
# Define a test function that takes a positional argument and arbitrary kwargs
def test_func(a, **kwargs):
return a + kwargs.get('increment', 0)

# Define the arguments and kwargs to pass to map2
args = [(1,), (2,), (3,)] # positional arguments
kwargs = [{'increment': 10}, {'increment': 20}, {'increment': 30}] # kwargs for each call

# Call map2 with the test function, args, kwargs, and parallelism=False
result = map2(
func=test_func,
args=args,
kwargs=kwargs,
parallelism=False
)

# Expected results after applying the function with the given args and kwargs
expected_results = [11, 22, 33]

# Assert that the actual result matches the expected result
self.assertEqual(result, expected_results)

def test_with_fixed_values_and_variable_kwargs_parallelism_false(self):
# Define a test function that takes fixed positional argument and arbitrary kwargs
def test_func(a, **kwargs):
return a + kwargs.get('increment', 0)

# Since 'a' is now a fixed value, we no longer need to provide it in args
args = [] # No positional arguments are passed here

# Define the kwargs to pass to map2, each dict represents kwargs for one call
kwargs = [{'increment': 10}, {'increment': 20}, {'increment': 30}]

# Call map2 with the test function, no args, variable kwargs, fixed_values containing 'a', and parallelism=False
result = map2(
func=test_func,
kwargs=kwargs,
fixed_values={'a': 1}, # 'a' is fixed for all calls
parallelism=False
)

# Expected results after applying the function with the fixed 'a' and given kwargs
expected_results = [11, 21, 31]

# Assert that the actual result matches the expected result
self.assertEqual(result, expected_results)


if __name__ == '__main__':
unittest.main()

0 comments on commit 54d3d13

Please sign in to comment.