@@ -94,7 +94,15 @@ def _evaluate_dataframe(
94
94
95
95
96
96
def _evaluate_basic (
97
- to_evaluate , input_cols , evaluation_function , func_args , func_kwargs , mapper , task_ids , db
97
+ to_evaluate ,
98
+ input_cols ,
99
+ evaluation_function ,
100
+ func_args ,
101
+ func_kwargs ,
102
+ mapper ,
103
+ task_ids ,
104
+ db ,
105
+ progress_bar = True ,
98
106
):
99
107
res = []
100
108
# Setup the function to apply to the data
@@ -109,8 +117,11 @@ def _evaluate_basic(
109
117
arg_list = list (to_evaluate .loc [task_ids , input_cols ].to_dict ("index" ).items ())
110
118
111
119
try :
120
+ tasks = mapper (eval_func , arg_list )
121
+ if progress_bar :
122
+ tasks = tqdm (tasks , total = len (task_ids ))
112
123
# Compute and collect the results
113
- for task_id , result , exception in tqdm ( mapper ( eval_func , arg_list ), total = len ( task_ids )) :
124
+ for task_id , result , exception in tasks :
114
125
res .append (dict ({"df_index" : task_id , "exception" : exception }, ** result ))
115
126
116
127
# Save the results into the DB
@@ -163,6 +174,7 @@ def evaluate(
163
174
func_args = None ,
164
175
func_kwargs = None ,
165
176
shuffle_rows = True ,
177
+ progress_bar = True ,
166
178
** mapper_kwargs ,
167
179
):
168
180
"""Evaluate and save results in a sqlite database on the fly and return dataframe.
@@ -185,12 +197,14 @@ def evaluate(
185
197
func_args (list): the arguments to pass to the evaluation_function.
186
198
func_kwargs (dict): the keyword arguments to pass to the evaluation_function.
187
199
shuffle_rows (bool): if :obj:`True`, it will shuffle the rows before computing the results.
200
+ progress_bar (bool): if :obj:`True`, a progress bar will be displayed during computation.
188
201
**mapper_kwargs: the keyword arguments are passed to the get_mapper() method of the
189
202
:class:`ParallelFactory` instance.
190
203
191
204
Return:
192
205
pandas.DataFrame: dataframe with new columns containing the computed results.
193
206
"""
207
+ # pylint: disable=too-many-branches
194
208
# Initialize the parallel factory
195
209
if isinstance (parallel_factory , str ) or parallel_factory is None :
196
210
parallel_factory = init_parallel_factory (parallel_factory )
@@ -243,6 +257,8 @@ def evaluate(
243
257
return to_evaluate
244
258
245
259
# Get the factory mapper
260
+ if isinstance (parallel_factory , DaskDataFrameFactory ):
261
+ mapper_kwargs ["progress_bar" ] = progress_bar
246
262
mapper = parallel_factory .get_mapper (** mapper_kwargs )
247
263
248
264
if isinstance (parallel_factory , DaskDataFrameFactory ):
@@ -267,6 +283,7 @@ def evaluate(
267
283
mapper ,
268
284
task_ids ,
269
285
db ,
286
+ progress_bar ,
270
287
)
271
288
to_evaluate .loc [res_df .index , res_df .columns ] = res_df
272
289
0 commit comments