Skip to content

Commit

Permalink
updated map2 to support vararg functions
Browse files Browse the repository at this point in the history
  • Loading branch information
davidparks21 committed Mar 12, 2024
1 parent 4e9173c commit 68d134e
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 12 deletions.
49 changes: 37 additions & 12 deletions src/braingeneers/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,17 +124,18 @@ def file_list(filepath: str) -> List[Tuple[str, str, int]]:


# Define the wrapper function as a top-level function
def _map2_wrapper(fixed_values: Dict[str, Any], required_params: List[str], func: Callable, args: Tuple) -> Any:
"""Internal wrapper function for map2 to handle fixed values and dynamic arguments."""
def _map2_wrapper(fixed_values: Dict[str, Any], required_params: List[str], func: Callable, args: Tuple, func_kwargs: Dict[str, Any]) -> Any:
"""Internal wrapper function for map2 to handle fixed values and dynamic arguments, including kwargs."""
# Merge fixed_values with provided arguments, aligning provided args with required_params
call_args = {**fixed_values, **dict(zip(required_params, args))}
return func(**call_args)
return func(**call_args, **func_kwargs)


def map2(func: Callable,
args: Iterable[Union[Tuple, object]] = None,
args: Iterable[Tuple[Any, ...]] = None,
kwargs: Iterable[Dict[str, Any]] = None,
fixed_values: dict = None,
parallelism: (bool, int) = True,
parallelism: Union[bool, int] = True,
use_multithreading: bool = False) -> List[object]:
"""
A universal multiprocessing version of the map function to simplify parallelizing code.
Expand All @@ -150,22 +151,35 @@ def map2(func: Callable,
def f(x, y):
print(x, y)
common_py_utils.map2(
common_utils.map2(
func=f,
args=[(1, 'yellow'), (2, 'yarn'), (3, 'yack')], # (x, y) arguments
parallelism=3, # use a 3 process multiprocessing pool
)
common_py_utils.map2(
common_utils.map2(
func=f,
args=[1, 2, 3], # x arguments
args=[1, 2, 3], # x arguments has multiple values to run
fixed_values=dict(y='yellow'), # y always is 'yellow'
parallelism=False, # Runs without parallelism which makes debugging exceptions easier
)
Usage example incorporating kwargs:
def myfunc(a, b, **kwargs):
print(a, b, kwargs.get('c'))
common_utils.map2(
func=myfunc,
args=[(1, 2), (3, 4)],
kwargs=[{'c': 50}, {'c': 100}],
)
:param func: a callable function
:param args: a list of arguments (if only 1 argument is left after fixed_values) or a list of tuples
(if multiple arguments are left after fixed_values)
:param kwargs: an iterable of dictionaries where each dictionary represents the keyword arguments to pass
to the function for each call. This parameter allows passing dynamic keyword arguments to the function.
the length of args and kwargs must be equal.
:param fixed_values: a dictionary with parameters that will stay the same for each call to func
:param parallelism: number of processes to use or boolean, default is # of CPU cores.
When parallelism==False or 1, this maps to itertools.starmap and does not use multiprocessing.
Expand All @@ -176,6 +190,9 @@ def f(x, y):
due to the GIL, threads are lighter weight than processes for some non cpu-intensive tasks.
:return: a list of the return values of func
"""
if args is not None and kwargs is not None:
assert len(args) == len(kwargs), \
"args and kwargs must have the same length, found lengths: len(args)={len(args)} and len(kwargs)={len(kwargs)}"
assert isinstance(fixed_values, (dict, type(None)))
assert parallelism is False or isinstance(parallelism, (bool, int)), "parallelism must be a boolean or an integer"
parallelism = multiprocessing.cpu_count() if parallelism is True else 1 if parallelism is False else parallelism
Expand All @@ -186,20 +203,28 @@ def f(x, y):
required_params = [p.name for p in func_signature.parameters.values() if
p.default == inspect.Parameter.empty and p.name not in fixed_values]

# Ensure args and kwargs are iterable or set them as empty lists if None
args_list = list(args or [])
kwargs_list = list(kwargs or [])
# Convert args to tuples if they're not already
args_tuples = args_list if all(isinstance(a, tuple) for a in args_list) else [(a,) for a in args_list]

# Prepare tuples of (fixed_values, required_params, func, args, kwargs) for each call
call_parameters = zip(args_tuples, kwargs_list) if kwargs_list else zip(args_tuples, [{}] * len(args_tuples))

if parallelism == 1:
result_iterator = map(lambda args: _map2_wrapper(fixed_values, required_params, func, args), args_tuples)
result_iterator = (map(lambda params: _map2_wrapper(fixed_values, required_params, func, params[0], params[1]),
call_parameters))
else:
ProcessOrThreadPool = multiprocessing.pool.ThreadPool if use_multithreading else multiprocessing.Pool
with ProcessOrThreadPool(parallelism) as pool:
result_iterator = pool.starmap(_map2_wrapper,
[(fixed_values, required_params, func, args) for args in args_tuples])
result_iterator = pool.starmap(
_map2_wrapper,
[(fixed_values, required_params, func, args, kw) for args, kw in call_parameters]
)

return list(result_iterator)


class checkout:
"""
A context manager for atomically checking out a file from S3 for reading or writing.
Expand Down
23 changes: 23 additions & 0 deletions src/braingeneers/utils/common_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,29 @@ def test_checkout_context_manager_write_binary(self):
locked_obj.checkin(test_data)
self.mock_file.write.assert_called_once_with(test_data)

def test_with_pass_through_kwargs_handling(self):
    """Test map2 with a function that accepts dynamic keyword arguments through **kwargs.

    Each dict in `experiments` supplies the 'experiment' keyword for one call. The entries of
    `fixed_values` that are not named parameters of f_with_kwargs ('metadata', 'channels', 'length')
    can likewise only reach the function through **kwargs.
    """

    def f_with_kwargs(cache_path: str, max_size_gb: int = 10, **kwargs):
        # Stand-in for a data loader: 'experiment' and the extra fixed values are expected in **kwargs.
        self.assertTrue(isinstance(kwargs, dict), 'kwargs should be a dict')
        # Guards against the kwargs dict being forwarded as a single keyword literally named 'kwargs'.
        self.assertFalse('kwargs' in kwargs)
        return 'some data'

    experiments = [{'experiment': 'exp1'}, {'experiment': 'exp2'}]  # One kwargs dict per intended call
    fixed_values = {
        "cache_path": '/tmp/ephys_cache',
        "max_size_gb": 50,
        "metadata": {"some": "metadata"},
        "channels": ["channel1"],
        "length": -1,
    }

    # parallelism=False is passed so that the locally defined f_with_kwargs is not sent to a process pool.
    results = map2(f_with_kwargs, kwargs=experiments, fixed_values=fixed_values, parallelism=False)

    # NOTE(review): no `args` are passed here. If map2 pairs args with kwargs by zipping them, an empty
    # args list yields zero calls and f_with_kwargs (with its assertions) never runs -- confirm, and once
    # that is settled consider asserting results == ['some data'] * len(experiments).
    self.assertIsInstance(results, list)


if __name__ == '__main__':
unittest.main()

0 comments on commit 68d134e

Please sign in to comment.