Skip to content

Commit d802540

Browse files
Feat: Can disable progress bar (#4)
1 parent 3e8998b commit d802540

File tree

3 files changed

+32
-5
lines changed

3 files changed

+32
-5
lines changed

bluepyparallel/evaluator.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,15 @@ def _evaluate_dataframe(
9494

9595

9696
def _evaluate_basic(
97-
to_evaluate, input_cols, evaluation_function, func_args, func_kwargs, mapper, task_ids, db
97+
to_evaluate,
98+
input_cols,
99+
evaluation_function,
100+
func_args,
101+
func_kwargs,
102+
mapper,
103+
task_ids,
104+
db,
105+
progress_bar=True,
98106
):
99107
res = []
100108
# Setup the function to apply to the data
@@ -109,8 +117,11 @@ def _evaluate_basic(
109117
arg_list = list(to_evaluate.loc[task_ids, input_cols].to_dict("index").items())
110118

111119
try:
120+
tasks = mapper(eval_func, arg_list)
121+
if progress_bar:
122+
tasks = tqdm(tasks, total=len(task_ids))
112123
# Compute and collect the results
113-
for task_id, result, exception in tqdm(mapper(eval_func, arg_list), total=len(task_ids)):
124+
for task_id, result, exception in tasks:
114125
res.append(dict({"df_index": task_id, "exception": exception}, **result))
115126

116127
# Save the results into the DB
@@ -163,6 +174,7 @@ def evaluate(
163174
func_args=None,
164175
func_kwargs=None,
165176
shuffle_rows=True,
177+
progress_bar=True,
166178
**mapper_kwargs,
167179
):
168180
"""Evaluate and save results in a sqlite database on the fly and return dataframe.
@@ -185,12 +197,14 @@ def evaluate(
185197
func_args (list): the arguments to pass to the evaluation_function.
186198
func_kwargs (dict): the keyword arguments to pass to the evaluation_function.
187199
shuffle_rows (bool): if :obj:`True`, it will shuffle the rows before computing the results.
200+
progress_bar (bool): if :obj:`True`, a progress bar will be displayed during computation.
188201
**mapper_kwargs: the keyword arguments are passed to the get_mapper() method of the
189202
:class:`ParallelFactory` instance.
190203
191204
Return:
192205
pandas.DataFrame: dataframe with new columns containing the computed results.
193206
"""
207+
# pylint: disable=too-many-branches
194208
# Initialize the parallel factory
195209
if isinstance(parallel_factory, str) or parallel_factory is None:
196210
parallel_factory = init_parallel_factory(parallel_factory)
@@ -243,6 +257,8 @@ def evaluate(
243257
return to_evaluate
244258

245259
# Get the factory mapper
260+
if isinstance(parallel_factory, DaskDataFrameFactory):
261+
mapper_kwargs["progress_bar"] = progress_bar
246262
mapper = parallel_factory.get_mapper(**mapper_kwargs)
247263

248264
if isinstance(parallel_factory, DaskDataFrameFactory):
@@ -267,6 +283,7 @@ def evaluate(
267283
mapper,
268284
task_ids,
269285
db,
286+
progress_bar,
270287
)
271288
to_evaluate.loc[res_df.index, res_df.columns] = res_df
272289

bluepyparallel/parallel.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,8 +379,13 @@ def _with_batches(self, *args, **kwargs):
379379
yield tmp
380380

381381
def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
382-
"""Get a Dask mapper."""
382+
"""Get a Dask mapper.
383+
384+
If ``progress_bar=True`` is passed as keyword argument, a progress bar will be displayed
385+
during computation.
386+
"""
383387
self._chunksize_to_kwargs(chunk_size, kwargs, label="chunksize")
388+
progress_bar = kwargs.pop("progress_bar", True)
384389
if not kwargs.get("chunksize"):
385390
kwargs["npartitions"] = self.nb_processes or 1
386391

@@ -389,7 +394,8 @@ def _dask_df_mapper(func, iterable):
389394
df = pd.DataFrame(iterable)
390395
ddf = dd.from_pandas(df, **kwargs)
391396
future = ddf.apply(func, meta=meta, axis=1).persist()
392-
progress(future)
397+
if progress_bar:
398+
progress(future)
393399
# Put into a list because of the 'yield from' in _with_batches
394400
return [future.compute()]
395401

tests/test_evaluator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,18 @@ class TestEvaluate:
9898
"""Test the ``bluepyparallel.evaluator.evaluate`` function."""
9999

100100
@pytest.mark.parametrize("with_sql", [True, False])
101-
def test_evaluate(self, input_df, new_columns, expected_df, db_url, with_sql, parallel_factory):
101+
@pytest.mark.parametrize("progress_bar", [True, False])
102+
def test_evaluate(
103+
self, input_df, new_columns, expected_df, db_url, with_sql, progress_bar, parallel_factory
104+
):
102105
"""Test evaluator on a trivial example."""
103106
result_df = evaluate(
104107
input_df,
105108
_evaluation_function,
106109
new_columns,
107110
parallel_factory=parallel_factory,
108111
db_url=db_url if with_sql else None,
112+
progress_bar=progress_bar,
109113
)
110114
if not with_sql:
111115
remove_sql_cols(expected_df)

0 commit comments

Comments
 (0)