From d3df08f709c1fad70e40ae4d910498d93bca4fc5 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Sat, 30 Mar 2024 07:57:34 -0500 Subject: [PATCH 01/11] Replace fastprogress with rich --- conda-envs/environment-dev.yml | 2 +- conda-envs/environment-docs.yml | 2 +- conda-envs/environment-jax.yml | 2 +- conda-envs/environment-test.yml | 2 +- conda-envs/windows-environment-dev.yml | 2 +- conda-envs/windows-environment-test.yml | 2 +- pymc/sampling/forward.py | 41 +++--- pymc/sampling/mcmc.py | 29 +++-- pymc/sampling/parallel.py | 94 ++++++++------ pymc/sampling/population.py | 66 ++++++---- pymc/smc/sampling.py | 73 +++++------ pymc/stats/log_density.py | 7 +- pymc/tuning/starting.py | 31 ++--- pymc/variational/inference.py | 165 ++++++++++++------------ requirements-dev.txt | 2 +- requirements.txt | 2 +- 16 files changed, 274 insertions(+), 248 deletions(-) diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index a2cf7c25d3a..a3f4a41a8c5 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -9,7 +9,6 @@ dependencies: - blas - cachetools>=4.2.1 - cloudpickle -- fastprogress>=0.2.0 - h5py>=2.7 - numpy>=1.15.0 - pandas>=0.24.0 @@ -28,6 +27,7 @@ dependencies: - pre-commit>=2.8.0 - pytest-cov>=2.5 - pytest>=3.0 +- rich>=13.7.1 - sphinx-copybutton - sphinx-design - sphinx-notfound-page diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml index 86227038372..d50328df26c 100644 --- a/conda-envs/environment-docs.yml +++ b/conda-envs/environment-docs.yml @@ -8,12 +8,12 @@ dependencies: - arviz>=0.13.0 - cachetools>=4.2.1 - cloudpickle -- fastprogress>=0.2.0 - numpy>=1.15.0 - pandas>=0.24.0 - pip - pytensor>=2.19,<2.20 - python-graphviz +- rich>=13.7.1 - scipy>=1.4.1 - typing-extensions>=3.7.4 # Extra dependencies for docs build diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-jax.yml index 0986f43046e..34379048015 100644 --- a/conda-envs/environment-jax.yml +++ b/conda-envs/environment-jax.yml @@ -9,7 +9,6 @@ dependencies: - blas - cachetools>=4.2.1 - cloudpickle -- fastprogress>=0.2.0 - h5py>=2.7 # Jaxlib version must not be greater than jax version! - blackjax>=1.0.0 @@ -24,6 +23,7 @@ dependencies: - pytensor>=2.19,<2.20 - python-graphviz - networkx +- rich>=13.7.1 - scipy>=1.4.1 - typing-extensions>=3.7.4 # Extra dependencies for testing diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 8272cca2396..6c0c2a0b61f 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -9,7 +9,6 @@ dependencies: - blas - cachetools>=4.2.1 - cloudpickle -- fastprogress>=0.2.0 - h5py>=2.7 - jax - libblas=*=*mkl @@ -20,6 +19,7 @@ dependencies: - pytensor>=2.19,<2.20 - python-graphviz - networkx +- rich>=13.7.1 - scipy>=1.4.1 - typing-extensions>=3.7.4 # Extra dependencies for testing diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index 25fdeb419ce..91df7bfbac5 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -9,7 +9,6 @@ dependencies: - blas - cachetools>=4.2.1 - cloudpickle -- fastprogress>=0.2.0 - h5py>=2.7 - numpy>=1.15.0 - pandas>=0.24.0 @@ -17,6 +16,7 @@ dependencies: - pytensor>=2.19,<2.20 - python-graphviz - networkx +- rich>=13.7.1 - scipy>=1.4.1 - typing-extensions>=3.7.4 # Extra dependencies for dev, testing and docs build diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 900e3e227e6..aaa958e985c 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -9,7 +9,6 @@ dependencies: - blas - cachetools>=4.2.1 - cloudpickle -- fastprogress>=0.2.0 - h5py>=2.7 - libpython - mkl-service>=2.3.0 @@ -20,6 +19,7 @@ dependencies: - pytensor>=2.19,<2.20 - python-graphviz - networkx +- rich>=13.7.1 - scipy>=1.4.1 - typing-extensions>=3.7.4 # Extra dependencies for testing diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index b63b9850a26..221bd6198c2 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -30,7 +30,6 @@ import xarray from arviz import InferenceData -from fastprogress.fastprogress import progress_bar from pytensor import tensor as pt from pytensor.graph.basic import ( Apply, @@ -46,6 +45,7 @@ RandomStateSharedVariable, ) from pytensor.tensor.sharedvar import SharedVariable +from rich.progress import Progress from typing_extensions import TypeAlias import pymc as pm @@ -796,10 +796,6 @@ def sample_posterior_predictive( else: vars_ = model.observed_RVs + observed_dependent_deterministics(model) - indices = np.arange(samples) - if progressbar: - indices = progress_bar(indices, total=samples, display=progressbar) - vars_to_sample = list(get_default_varnames(vars_, include_transformed=False)) if not vars_to_sample: @@ -834,25 +830,28 @@ def sample_posterior_predictive( _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore ppc_trace_t = _DefaultTrace(samples) try: - for idx in indices: - if nchain > 1: - # the trace object will either be a MultiTrace (and have _straces)... - if hasattr(_trace, "_straces"): - chain_idx, point_idx = np.divmod(idx, len_trace) - chain_idx = chain_idx % nchain - param = cast(MultiTrace, _trace)._straces[chain_idx].point(point_idx) - # ... or a PointList + with Progress() as progress: + task = progress.add_task("Sampling", visible=progressbar, total=samples) + for idx in np.arange(samples): + progress.update(task, advance=1) + if nchain > 1: + # the trace object will either be a MultiTrace (and have _straces)... + if hasattr(_trace, "_straces"): + chain_idx, point_idx = np.divmod(idx, len_trace) + chain_idx = chain_idx % nchain + param = cast(MultiTrace, _trace)._straces[chain_idx].point(point_idx) + # ... or a PointList + else: + param = cast(PointList, _trace)[idx % (len_trace * nchain)] + # there's only a single chain, but the index might hit it multiple times if + # the number of indices is greater than the length of the trace. else: - param = cast(PointList, _trace)[idx % (len_trace * nchain)] - # there's only a single chain, but the index might hit it multiple times if - # the number of indices is greater than the length of the trace. - else: - param = _trace[idx % len_trace] + param = _trace[idx % len_trace] - values = sampler_fn(**param) + values = sampler_fn(**param) - for k, v in zip(vars_, values): - ppc_trace_t.insert(k.name, v, idx) + for k, v in zip(vars_, values): + ppc_trace_t.insert(k.name, v, idx) except KeyboardInterrupt: pass diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index dd97f78c884..ea38f5626f3 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -34,8 +34,8 @@ from arviz import InferenceData, dict_to_dataset from arviz.data.base import make_attrs -from fastprogress.fastprogress import progress_bar from pytensor.graph.basic import Variable +from rich.progress import Progress from typing_extensions import Protocol, TypeAlias import pymc as pm @@ -1026,19 +1026,20 @@ def _sample( ) _pbar_data = {"chain": chain, "divergences": 0} _desc = "Sampling chain {chain:d}, {divergences:,d} divergences" - if progressbar: - sampling = progress_bar(sampling_gen, total=draws, display=progressbar) - sampling.comment = _desc.format(**_pbar_data) - else: - sampling = sampling_gen - try: - for it, diverging in enumerate(sampling): - if it >= skip_first and diverging: - _pbar_data["divergences"] += 1 - if progressbar: - sampling.comment = _desc.format(**_pbar_data) - except KeyboardInterrupt: - pass + # if progressbar: + # sampling = progress_bar(sampling_gen, total=draws, display=progressbar) + # sampling.comment = _desc.format(**_pbar_data) + # else: + # sampling = sampling_gen + with Progress() as progress: + try: + task = progress.add_task(_desc.format(**_pbar_data), total=draws) + for it, diverging in enumerate(sampling_gen): + if it >= skip_first and diverging: + _pbar_data["divergences"] += 1 + progress.update(task, advance=1) + except KeyboardInterrupt: + pass def _iter_sample( diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 430c361cac1..a2919a086c2 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -26,7 +26,7 @@ import cloudpickle import numpy as np -from fastprogress.fastprogress import progress_bar +from rich import progress from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError @@ -420,14 +420,21 @@ def __init__( self._in_context = False - self._progress = None + self._progress = progress.Progress( + "[progress.description]{task.description}", + progress.BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + progress.TimeRemainingColumn(), + ) + self._show_progress = progressbar self._divergences = 0 - self._total_draws = 0 + self._completed_draws = 0 + self._total_draws = chains * (draws + tune) self._desc = "Sampling {0._chains:d} chains, {0._divergences:,d} divergences" self._chains = chains - if progressbar: - self._progress = progress_bar(range(chains * (draws + tune)), display=progressbar) - self._progress.comment = self._desc.format(self) + # if progressbar: + # self._progress = progress_bar(range(chains * (draws + tune)), display=progressbar) + # self._progress.comment = self._desc.format(self) def _make_active(self): while self._inactive and len(self._active) < self._max_active: @@ -441,37 +448,50 @@ def __iter__(self): raise ValueError("Use ParallelSampler as context manager.") self._make_active() - if self._active and self._progress: - self._progress.update(self._total_draws) - - while self._active: - draw = ProcessAdapter.recv_draw(self._active) - proc, is_last, draw, tuning, stats = draw - self._total_draws += 1 - if not tuning and stats and stats[0].get("diverging"): - self._divergences += 1 - if self._progress: - self._progress.comment = self._desc.format(self) - if self._progress: - self._progress.update(self._total_draws) - - if is_last: - proc.join() - self._active.remove(proc) - self._finished.append(proc) - self._make_active() - - # We could also yield proc.shared_point_view directly, - # and only call proc.write_next() after the yield returns. - # This seems to be faster overally though, as the worker - # loses less time waiting. - point = {name: val.copy() for name, val in proc.shared_point_view.items()} - - # Already called for new proc in _make_active - if not is_last: - proc.write_next() - - yield Draw(proc.chain, is_last, draw, tuning, stats, point) + with self._progress as progress: + task = progress.add_task( + self._desc.format(self), + completed=self._completed_draws, + total=self._total_draws, + visible=self._show_progress, + ) + + # if self._active and self._progress: + # self._progress.update(self._total_draws) + # progress.update( + # task, divergences=self._divergences + # ) + + while self._active: + draw = ProcessAdapter.recv_draw(self._active) + proc, is_last, draw, tuning, stats = draw + self._completed_draws += 1 + if not tuning and stats and stats[0].get("diverging"): + self._divergences += 1 + progress.update( + task, + completed=self._completed_draws, + total=self._total_draws, + description=self._desc.format(self), + ) + + if is_last: + proc.join() + self._active.remove(proc) + self._finished.append(proc) + self._make_active() + + # We could also yield proc.shared_point_view directly, + # and only call proc.write_next() after the yield returns. + # This seems to be faster overally though, as the worker + # loses less time waiting. + point = {name: val.copy() for name, val in proc.shared_point_view.items()} + + # Already called for new proc in _make_active + if not is_last: + proc.write_next() + + yield Draw(proc.chain, is_last, draw, tuning, stats, point) def __enter__(self): self._in_context = True diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index c38b90599b2..2a0db2ecfa8 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -24,7 +24,7 @@ import cloudpickle import numpy as np -from fastprogress.fastprogress import progress_bar +from rich.progress import BarColumn, Progress, TimeRemainingColumn from typing_extensions import TypeAlias from pymc.backends.base import BaseTrace @@ -101,11 +101,12 @@ def _sample_population( progressbar=progressbar, ) - if progressbar: - sampling = progress_bar(sampling, total=draws, display=progressbar) + with Progress() as progress: + task = progress.add_task("[red]Sampling...", total=draws, visible=progressbar) + + for _ in sampling: + progress.update(task, advance=1) - for i in sampling: - pass return @@ -166,6 +167,7 @@ def __init__(self, steppers, parallelize: bool, progressbar: bool = True): self._primary_ends = [] self._processes = [] self._steppers = steppers + self._progress = None if parallelize: try: # configure a child process for each stepper @@ -174,25 +176,34 @@ def __init__(self, steppers, parallelize: bool, progressbar: bool = True): ) import multiprocessing - for c, stepper in ( - enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers) - ): - secondary_end, primary_end = multiprocessing.Pipe() - stepper_dumps = cloudpickle.dumps(stepper, protocol=4) - process = multiprocessing.Process( - target=self.__class__._run_secondary, - args=(c, stepper_dumps, secondary_end), - name=f"ChainWalker{c}", - ) - # we want the child process to exit if the parent is terminated - process.daemon = True - # Starting the process might fail and takes time. - # By doing it in the constructor, the sampling progress bar - # will not be confused by the process start. - process.start() - self._primary_ends.append(primary_end) - self._processes.append(process) - self.is_parallelized = True + with Progress( + "[progress.description]{task.description}", + BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + TimeRemainingColumn(), + ) as self._progress: + for c, stepper in enumerate(steppers): + # enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers) + # ): + task = self._progress.add_task( + description=f"Chain {c}", visible=progressbar + ) + secondary_end, primary_end = multiprocessing.Pipe() + stepper_dumps = cloudpickle.dumps(stepper, protocol=4) + process = multiprocessing.Process( + target=self.__class__._run_secondary, + args=(c, stepper_dumps, secondary_end, task, self._progress), + name=f"ChainWalker{c}", + ) + # we want the child process to exit if the parent is terminated + process.daemon = True + # Starting the process might fail and takes time. + # By doing it in the constructor, the sampling progress bar + # will not be confused by the process start. + process.start() + self._primary_ends.append(primary_end) + self._processes.append(process) + self.is_parallelized = True except Exception: _log.info( "Population parallelization failed. " @@ -222,7 +233,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): return @staticmethod - def _run_secondary(c, stepper_dumps, secondary_end): + def _run_secondary(c, stepper_dumps, secondary_end, task, progress): """The method is started on a separate process to perform stepping of a chain. Parameters @@ -233,6 +244,10 @@ def _run_secondary(c, stepper_dumps, secondary_end): a step method such as CompoundStep secondary_end : multiprocessing.connection.PipeConnection This is our connection to the main process + task : progress.Task + The progress task for this chain + progress : progress.Progress + The progress bar """ # re-seed each child process to make them unique np.random.seed(None) @@ -259,6 +274,7 @@ def _run_secondary(c, stepper_dumps, secondary_end): for popstep in population_steppers: popstep.population = population update = stepper.step(population[c]) + progress.advance(task) secondary_end.send(update) except Exception: _log.exception(f"ChainWalker{c}") diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 2ea3800acec..81ec9fd1007 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -25,7 +25,9 @@ import numpy as np from arviz import InferenceData -from fastprogress.fastprogress import force_console_behavior, progress_bar + +# from fastprogress.fastprogress import force_console_behavior, progress_bar +from rich.progress import Progress, TextColumn, get_default_columns import pymc @@ -310,7 +312,8 @@ def _sample_smc_int( model, random_seed, chain, - progressbar=None, + pbar, + pbar_visible, **kernel_kwargs, ): """Run one SMC instance.""" @@ -337,9 +340,9 @@ def _sample_smc_int( **kernel_kwargs, ) - if progressbar: - progressbar.comment = f"{getattr(progressbar, 'base_comment', '')} Stage: 0 Beta: 0" - progressbar.update_bar(getattr(progressbar, "offset", 0) + 0) + task = pbar.add_task( + f"Chain: {chain + 1}", total=100, comment="Stage: 0 Beta: 0", visible=pbar_visible + ) smc._initialize_kernel() smc.setup_kernel() @@ -349,11 +352,7 @@ def _sample_smc_int( while smc.beta < 1: smc.update_beta_and_weights() - if progressbar: - progressbar.comment = ( - f"{getattr(progressbar, 'base_comment', '')} Stage: {stage} Beta: {smc.beta:.3f}" - ) - progressbar.update_bar(getattr(progressbar, "offset", 0) + int(smc.beta * 100)) + pbar.update(task, advance=1, comment=f"Stage: {stage} Beta: {smc.beta:.3f}") smc.resample() smc.tune() @@ -376,38 +375,36 @@ def _sample_smc_int( def run_chains_parallel(chains, progressbar, to_run, params, random_seed, kernel_kwargs, cores): - # fastprogress HTML progress bar does not support multiprocessing - _, progress_bar = force_console_behavior() - pbar = progress_bar((), total=100, display=progressbar) - pbar.update(0) - pbars = [pbar] + [None] * (chains - 1) - - pool = mp.Pool(cores) - - # "manually" (de)serialize params before/after multiprocessing - params = tuple(cloudpickle.dumps(p) for p in params) - kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()} - results = _starmap_with_kwargs( - pool, - to_run, - [(*params, random_seed[chain], chain, pbars[chain]) for chain in range(chains)], - repeat(kernel_kwargs), - ) - results = tuple(cloudpickle.loads(r) for r in results) - pool.close() - pool.join() - return results + with Progress( + *get_default_columns(), + TextColumn("{task.comment}"), + ) as pbar: + pool = mp.Pool(cores) + + # "manually" (de)serialize params before/after multiprocessing + params = tuple(cloudpickle.dumps(p) for p in params) + kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()} + results = _starmap_with_kwargs( + pool, + to_run, + [(*params, random_seed[chain], chain, pbar, progressbar) for chain in range(chains)], + repeat(kernel_kwargs), + ) + results = tuple(cloudpickle.loads(r) for r in results) + pool.close() + pool.join() + return results def run_chains_sequential(chains, progressbar, to_run, params, random_seed, kernel_kwargs): results = [] - pbar = progress_bar((), total=100 * chains, display=progressbar) - pbar.update(0) - for chain in range(chains): - pbar.offset = 100 * chain - pbar.base_comment = f"Chain: {chain + 1}/{chains}" - results.append(to_run(*params, random_seed[chain], chain, pbar, **kernel_kwargs)) - return results + with Progress( + *get_default_columns(), + TextColumn("{task.comment}"), + ) as pbar: + for chain in range(chains): + results.append(to_run(*params, random_seed[chain], chain, pbar, **kernel_kwargs)) + return results def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter): diff --git a/pymc/stats/log_density.py b/pymc/stats/log_density.py index e72e2445799..1435560437b 100644 --- a/pymc/stats/log_density.py +++ b/pymc/stats/log_density.py @@ -15,7 +15,7 @@ from typing import Optional, cast from arviz import InferenceData, dict_to_dataset -from fastprogress import progress_bar +from rich.progress import track import pymc @@ -169,9 +169,10 @@ def compute_log_density( n_pts = len(posterior_pts) logdens_dict = _DefaultTrace(n_pts) - indices = range(n_pts) if progressbar: - indices = progress_bar(indices, total=n_pts, display=progressbar) + indices = track(range(n_pts), description="Computing log density") + else: + indices = range(n_pts) for idx in indices: logdenss_pts = elemwise_logdens_fn(posterior_pts[idx]) diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index 8db885057b8..903ed627d87 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -27,9 +27,11 @@ import numpy as np import pytensor.gradient as tg -from fastprogress.fastprogress import ProgressBar, progress_bar + +# from fastprogress.fastprogress import ProgressBar, progress_bar from numpy import isfinite from pytensor import Variable +from rich import progress from scipy.optimize import minimize import pymc as pm @@ -174,12 +176,8 @@ def find_MAP( if isinstance(e, StopIteration): pm._log.info(e) finally: - last_v = cost_func.n_eval - if progressbar: - assert isinstance(cost_func.progress, ProgressBar) - cost_func.progress.total = last_v - cost_func.progress.update(last_v) - print(file=sys.stdout) + cost_func.progress.update(cost_func.task, total=cost_func.n_eval) + print(file=sys.stdout) mx0 = RaveledVars(mx0, x0.point_map_info) unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed) @@ -212,11 +210,8 @@ def __init__(self, maxeval=5000, progressbar=True, logp_func=None, dlogp_func=No self.desc = "logp = {:,.5g}, ||grad|| = {:,.5g}" self.previous_x = None self.progressbar = progressbar - if progressbar: - self.progress = progress_bar(range(maxeval), total=maxeval, display=progressbar) - self.progress.update(0) - else: - self.progress = range(maxeval) + self.progress = progress.Progress() + self.task = self.progress.add_task("MAP", total=maxeval, visible=progressbar) def __call__(self, x): neg_value = np.float64(self.logp_func(pm.floatX(x))) @@ -232,16 +227,14 @@ def __call__(self, x): grad = None if self.n_eval % 10 == 0: - self.update_progress_desc(neg_value, grad) + self.progress.update(self.task, description=self.update_progress_desc(neg_value, grad)) if self.n_eval > self.maxeval: - self.update_progress_desc(neg_value, grad) + self.progress.update(self.task, description=self.update_progress_desc(neg_value, grad)) raise StopIteration self.n_eval += 1 - if self.progressbar: - assert isinstance(self.progress, ProgressBar) - self.progress.update_bar(self.n_eval) + self.progress.advance(self.task, 1) if self.use_gradient: return value, grad @@ -251,7 +244,7 @@ def __call__(self, x): def update_progress_desc(self, neg_value: float, grad: np.float64 = None) -> None: if self.progressbar: if grad is None: - self.progress.comment = self.desc.format(neg_value) + return self.desc.format(neg_value) else: norm_grad = np.linalg.norm(grad) - self.progress.comment = self.desc.format(neg_value, norm_grad) + return self.desc.format(neg_value, norm_grad) diff --git a/pymc/variational/inference.py b/pymc/variational/inference.py index 6ee5815d145..ecf635fec95 100644 --- a/pymc/variational/inference.py +++ b/pymc/variational/inference.py @@ -18,7 +18,7 @@ import numpy as np -from fastprogress.fastprogress import progress_bar +from rich.progress import Progress, track import pymc as pm @@ -83,9 +83,8 @@ def run_profiling(self, n=1000, score=None, **kwargs): fn_kwargs = kwargs.pop("fn_kwargs", dict()) fn_kwargs["profile"] = True step_func = self.objective.step_function(score=score, fn_kwargs=fn_kwargs, **kwargs) - progress = progress_bar(range(n)) try: - for _ in progress: + for _ in track(range(n)): step_func() except KeyboardInterrupt: pass @@ -136,14 +135,11 @@ def fit(self, n=10000, score=None, callbacks=None, progressbar=True, **kwargs): callbacks = [] score = self._maybe_score(score) step_func = self.objective.step_function(score=score, **kwargs) - if progressbar: - progress = progress_bar(range(n), display=progressbar) - else: - progress = range(n) + if score: - state = self._iterate_with_loss(0, n, step_func, progress, callbacks) + state = self._iterate_with_loss(0, n, step_func, progressbar, callbacks) else: - state = self._iterate_without_loss(0, n, step_func, progress, callbacks) + state = self._iterate_without_loss(0, n, step_func, progressbar, callbacks) # hack to allow pm.fit() access to loss hist self.approx.hist = self.hist @@ -151,43 +147,46 @@ def fit(self, n=10000, score=None, callbacks=None, progressbar=True, **kwargs): return self.approx - def _iterate_without_loss(self, s, _, step_func, progress, callbacks): + def _iterate_without_loss(self, s, n, step_func, progressbar, callbacks): i = 0 try: - for i in progress: - step_func() - current_param = self.approx.params[0].get_value() - if np.isnan(current_param).any(): - name_slc = [] - tmp_hold = list(range(current_param.size)) - for varname, slice_info in self.approx.groups[0].ordering.items(): - slclen = len(tmp_hold[slice_info[1]]) - for j in range(slclen): - name_slc.append((varname, j)) - index = np.where(np.isnan(current_param))[0] - errmsg = ["NaN occurred in optimization. "] - suggest_solution = ( - "Try tracking this parameter: " - "http://docs.pymc.io/notebooks/variational_api_quickstart.html#Tracking-parameters" - ) - try: - for ii in index: - errmsg.append( - "The current approximation of RV `{}`.ravel()[{}]" - " is NaN.".format(*name_slc[ii]) - ) - errmsg.append(suggest_solution) - except IndexError: - pass - raise FloatingPointError("\n".join(errmsg)) - for callback in callbacks: - callback(self.approx, None, i + s + 1) + with Progress() as progress: + task = progress.add_task("Fitting", total=n, visible=progressbar) + for i in range(n): + step_func() + progress.update(task, advance=1) + current_param = self.approx.params[0].get_value() + if np.isnan(current_param).any(): + name_slc = [] + tmp_hold = list(range(current_param.size)) + for varname, slice_info in self.approx.groups[0].ordering.items(): + slclen = len(tmp_hold[slice_info[1]]) + for j in range(slclen): + name_slc.append((varname, j)) + index = np.where(np.isnan(current_param))[0] + errmsg = ["NaN occurred in optimization. "] + suggest_solution = ( + "Try tracking this parameter: " + "http://docs.pymc.io/notebooks/variational_api_quickstart.html#Tracking-parameters" + ) + try: + for ii in index: + errmsg.append( + "The current approximation of RV `{}`.ravel()[{}]" + " is NaN.".format(*name_slc[ii]) + ) + errmsg.append(suggest_solution) + except IndexError: + pass + raise FloatingPointError("\n".join(errmsg)) + for callback in callbacks: + callback(self.approx, None, i + s + 1) except (KeyboardInterrupt, StopIteration) as e: if isinstance(e, StopIteration): logger.info(str(e)) return State(i + s, step=step_func, callbacks=callbacks, score=False) - def _iterate_with_loss(self, s, n, step_func, progress, callbacks): + def _iterate_with_loss(self, s, n, step_func, progressbar, callbacks): def _infmean(input_array): """Return the mean of the finite values of the array""" input_array = input_array[np.isfinite(input_array)].astype("float64") @@ -200,44 +199,48 @@ def _infmean(input_array): scores[:] = np.nan i = 0 try: - for i in progress: - e = step_func() - if np.isnan(e): - scores = scores[:i] - self.hist = np.concatenate([self.hist, scores]) - current_param = self.approx.params[0].get_value() - name_slc = [] - tmp_hold = list(range(current_param.size)) - for varname, slice_info in self.approx.groups[0].ordering.items(): - slclen = len(tmp_hold[slice_info[1]]) - for j in range(slclen): - name_slc.append((varname, j)) - index = np.where(np.isnan(current_param))[0] - errmsg = ["NaN occurred in optimization. "] - suggest_solution = ( - "Try tracking this parameter: " - "http://docs.pymc.io/notebooks/variational_api_quickstart.html#Tracking-parameters" - ) - try: - for ii in index: - errmsg.append( - "The current approximation of RV `{}`.ravel()[{}]" - " is NaN.".format(*name_slc[ii]) - ) - errmsg.append(suggest_solution) - except IndexError: - pass - raise FloatingPointError("\n".join(errmsg)) - scores[i] = e - if i % 10 == 0: - avg_loss = _infmean(scores[max(0, i - 1000) : i + 1]) - if hasattr(progress, "comment"): - progress.comment = f"Average Loss = {avg_loss:,.5g}" - avg_loss = scores[max(0, i - 1000) : i + 1].mean() - if hasattr(progress, "comment"): - progress.comment = f"Average Loss = {avg_loss:,.5g}" - for callback in callbacks: - callback(self.approx, scores[: i + 1], i + s + 1) + with Progress( + *Progress.get_default_columns(), + Progress.TextColumn("{task.loss}"), + ) as progress: + task = progress.add_task("Fitting:", total=n, visible=progressbar) + for i in range(n): + e = step_func() + progress.update(task, advance=1) + if np.isnan(e): + scores = scores[:i] + self.hist = np.concatenate([self.hist, scores]) + current_param = self.approx.params[0].get_value() + name_slc = [] + tmp_hold = list(range(current_param.size)) + for varname, slice_info in self.approx.groups[0].ordering.items(): + slclen = len(tmp_hold[slice_info[1]]) + for j in range(slclen): + name_slc.append((varname, j)) + index = np.where(np.isnan(current_param))[0] + errmsg = ["NaN occurred in optimization. "] + suggest_solution = ( + "Try tracking this parameter: " + "http://docs.pymc.io/notebooks/variational_api_quickstart.html#Tracking-parameters" + ) + try: + for ii in index: + errmsg.append( + "The current approximation of RV `{}`.ravel()[{}]" + " is NaN.".format(*name_slc[ii]) + ) + errmsg.append(suggest_solution) + except IndexError: + pass + raise FloatingPointError("\n".join(errmsg)) + scores[i] = e + if i % 10 == 0: + avg_loss = _infmean(scores[max(0, i - 1000) : i + 1]) + progress.update(task, loss=f"Average Loss = {avg_loss:,.5g}") + avg_loss = scores[max(0, i - 1000) : i + 1].mean() + progress.update(task, loss=f"Average Loss = {avg_loss:,.5g}") + for callback in callbacks: + callback(self.approx, scores[: i + 1], i + s + 1) except (KeyboardInterrupt, StopIteration) as e: # pragma: no cover # do not print log on the same line scores = scores[:i] @@ -266,14 +269,10 @@ def refine(self, n, progressbar=True): if self.state is None: raise TypeError("Need to call `.fit` first") i, step, callbacks, score = self.state - if progressbar: - progress = progress_bar(range(n), display=progressbar) - else: - progress = range(n) # This is a guess at what progress_bar(n) does. if score: - state = self._iterate_with_loss(i, n, step, progress, callbacks) + state = self._iterate_with_loss(i, n, step, progressbar, callbacks) else: - state = self._iterate_without_loss(i, n, step, progress, callbacks) + state = self._iterate_without_loss(i, n, step, progressbar, callbacks) self.state = state diff --git a/requirements-dev.txt b/requirements-dev.txt index ddcf9ded9b9..56077f3a6b1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,7 +4,6 @@ arviz>=0.13.0 cachetools>=4.2.1 cloudpickle -fastprogress>=0.2.0 git+https://github.com/pymc-devs/pymc-sphinx-theme h5py>=2.7 ipython>=7.16 @@ -21,6 +20,7 @@ pre-commit>=2.8.0 pytensor>=2.19,<2.20 pytest-cov>=2.5 pytest>=3.0 +rich>=13.7.1 scipy>=1.4.1 sphinx-copybutton sphinx-design diff --git a/requirements.txt b/requirements.txt index 0bc21049c15..370dcbd41e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,9 @@ arviz>=0.13.0 cachetools>=4.2.1 cloudpickle -fastprogress>=0.2.0 numpy>=1.15.0 pandas>=0.24.0 pytensor>=2.19,<2.20 +rich>=13.7.1 scipy>=1.4.1 typing-extensions>=3.7.4 From 0a7b02ca116570da16590de6bf51c1c9df961efb Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Sat, 30 Mar 2024 08:20:07 -0500 Subject: [PATCH 02/11] Bugfixes for ADVI progress bars --- pymc/smc/sampling.py | 6 +++--- pymc/variational/inference.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 81ec9fd1007..edc92b402f2 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -27,7 +27,7 @@ from arviz import InferenceData # from fastprogress.fastprogress import force_console_behavior, progress_bar -from rich.progress import Progress, TextColumn, get_default_columns +from rich.progress import Progress, TextColumn import pymc @@ -376,7 +376,7 @@ def _sample_smc_int( def run_chains_parallel(chains, progressbar, to_run, params, random_seed, kernel_kwargs, cores): with Progress( - *get_default_columns(), + *Progress.get_default_columns(), TextColumn("{task.comment}"), ) as pbar: pool = mp.Pool(cores) @@ -399,7 +399,7 @@ def run_chains_parallel(chains, progressbar, to_run, params, random_seed, kernel def run_chains_sequential(chains, progressbar, to_run, params, random_seed, kernel_kwargs): results = [] with Progress( - *get_default_columns(), + *Progress.get_default_columns(), TextColumn("{task.comment}"), ) as pbar: for chain in range(chains): diff --git a/pymc/variational/inference.py b/pymc/variational/inference.py index ecf635fec95..dff402a86e4 100644 --- a/pymc/variational/inference.py +++ b/pymc/variational/inference.py @@ -18,7 +18,7 @@ import numpy as np -from rich.progress import Progress, track +from rich.progress import Progress, TextColumn, track import pymc as pm @@ -201,9 +201,9 @@ def _infmean(input_array): try: with Progress( *Progress.get_default_columns(), - Progress.TextColumn("{task.loss}"), + TextColumn("{task.fields[loss]}"), ) as progress: - task = progress.add_task("Fitting:", total=n, visible=progressbar) + task = progress.add_task("Fitting:", total=n, visible=progressbar, loss="") for i in range(n): e = step_func() progress.update(task, advance=1) From 7634baa14513f0f6b0a2dc853fe38a43c00e32e6 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Sat, 30 Mar 2024 08:37:04 -0500 Subject: [PATCH 03/11] Bugfixes for MAP progress bars --- pymc/tuning/starting.py | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index 903ed627d87..c2a7ddf2021 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -31,7 +31,7 @@ # from fastprogress.fastprogress import ProgressBar, progress_bar from numpy import isfinite from pytensor import Variable -from rich import progress +from rich.progress import Progress, TextColumn from scipy.optimize import minimize import pymc as pm @@ -166,18 +166,19 @@ def find_MAP( cost_func = CostFuncWrapper(maxeval, progressbar, logp_func) compute_gradient = False - try: - opt_result = minimize( - cost_func, x0.data, method=method, jac=compute_gradient, *args, **kwargs - ) - mx0 = opt_result["x"] # r -> opt_result - except (KeyboardInterrupt, StopIteration) as e: - mx0, opt_result = cost_func.previous_x, None - if isinstance(e, StopIteration): - pm._log.info(e) - finally: - cost_func.progress.update(cost_func.task, total=cost_func.n_eval) - print(file=sys.stdout) + with cost_func.progress: + try: + opt_result = minimize( + cost_func, x0.data, method=method, jac=compute_gradient, *args, **kwargs + ) + mx0 = opt_result["x"] # r -> opt_result + except (KeyboardInterrupt, StopIteration) as e: + mx0, opt_result = cost_func.previous_x, None + if isinstance(e, StopIteration): + pm._log.info(e) + finally: + cost_func.progress.update(cost_func.task, total=cost_func.n_eval) + print(file=sys.stdout) mx0 = RaveledVars(mx0, x0.point_map_info) unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed) @@ -210,8 +211,11 @@ def __init__(self, maxeval=5000, progressbar=True, logp_func=None, dlogp_func=No self.desc = "logp = {:,.5g}, ||grad|| = {:,.5g}" self.previous_x = None self.progressbar = progressbar - self.progress = progress.Progress() - self.task = self.progress.add_task("MAP", total=maxeval, visible=progressbar) + self.progress = Progress( + *Progress.get_default_columns(), + TextColumn("{task.fields[loss]}"), + ) + self.task = self.progress.add_task("MAP", total=maxeval, visible=progressbar, loss="") def __call__(self, x): neg_value = np.float64(self.logp_func(pm.floatX(x))) @@ -227,10 +231,10 @@ def __call__(self, x): grad = None if self.n_eval % 10 == 0: - self.progress.update(self.task, description=self.update_progress_desc(neg_value, grad)) + self.progress.update(self.task, loss=self.update_progress_desc(neg_value, grad)) if self.n_eval > self.maxeval: - self.progress.update(self.task, description=self.update_progress_desc(neg_value, grad)) + self.progress.update(self.task, loss=self.update_progress_desc(neg_value, grad)) raise StopIteration self.n_eval += 1 From 9d1702d8a273815b0f25c2caa024b1352c07deaa Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Mon, 1 Apr 2024 09:52:55 -0500 Subject: [PATCH 04/11] Fixed final update to progress bar --- pymc/sampling/mcmc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index ea38f5626f3..18f724e8e33 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1038,6 +1038,7 @@ def _sample( if it >= skip_first and diverging: _pbar_data["divergences"] += 1 progress.update(task, advance=1) + progress.update(task, advance=1, completed=True) except KeyboardInterrupt: pass From 32c9c11402c0e0e10316b0e09dd77a96e3e18128 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Mon, 1 Apr 2024 11:18:31 -0500 Subject: [PATCH 05/11] SMC progress bar working --- pymc/sampling/forward.py | 5 +- pymc/smc/sampling.py | 110 ++++++++++++++++++--------------------- 2 files changed, 53 insertions(+), 62 deletions(-) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index 221bd6198c2..910a8707f0b 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -831,9 +831,7 @@ def sample_posterior_predictive( ppc_trace_t = _DefaultTrace(samples) try: with Progress() as progress: - task = progress.add_task("Sampling", visible=progressbar, total=samples) - for idx in np.arange(samples): - progress.update(task, advance=1) + for idx in progress.track(np.arange(samples), description="Sampling ..."): if nchain > 1: # the trace object will either be a MultiTrace (and have _straces)... if hasattr(_trace, "_straces"): @@ -852,6 +850,7 @@ def sample_posterior_predictive( for k, v in zip(vars_, values): ppc_trace_t.insert(k.name, v, idx) + except KeyboardInterrupt: pass diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index edc92b402f2..366b068e8a4 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -13,12 +13,12 @@ # limitations under the License. import logging -import multiprocessing as mp +import multiprocessing import time import warnings from collections import defaultdict -from itertools import repeat +from concurrent.futures import ProcessPoolExecutor from typing import Any, Optional, Union import cloudpickle @@ -27,7 +27,7 @@ from arviz import InferenceData # from fastprogress.fastprogress import force_console_behavior, progress_bar -from rich.progress import Progress, TextColumn +from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn import pymc @@ -211,14 +211,8 @@ def sample_smc( t1 = time.time() - if cores > 1: - results = run_chains_parallel( - chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs, cores - ) - else: - results = run_chains_sequential( - chains, progressbar, _sample_smc_int, params, random_seed, kernel_kwargs - ) + results = run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores) + ( traces, sample_stats, @@ -312,8 +306,8 @@ def _sample_smc_int( model, random_seed, chain, - pbar, - pbar_visible, + progress_dict, + task_id, **kernel_kwargs, ): """Run one SMC instance.""" @@ -340,10 +334,6 @@ def _sample_smc_int( **kernel_kwargs, ) - task = pbar.add_task( - f"Chain: {chain + 1}", total=100, comment="Stage: 0 Beta: 0", visible=pbar_visible - ) - smc._initialize_kernel() smc.setup_kernel() @@ -352,7 +342,7 @@ def _sample_smc_int( while smc.beta < 1: smc.update_beta_and_weights() - pbar.update(task, advance=1, comment=f"Stage: {stage} Beta: {smc.beta:.3f}") + progress_dict[task_id] = {"stage": stage, "beta": smc.beta} smc.resample() smc.tune() @@ -374,45 +364,47 @@ def _sample_smc_int( return results -def run_chains_parallel(chains, progressbar, to_run, params, random_seed, kernel_kwargs, cores): - with Progress( - *Progress.get_default_columns(), - TextColumn("{task.comment}"), - ) as pbar: - pool = mp.Pool(cores) - - # "manually" (de)serialize params before/after multiprocessing - params = tuple(cloudpickle.dumps(p) for p in params) - kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()} - results = _starmap_with_kwargs( - pool, - to_run, - [(*params, random_seed[chain], chain, pbar, progressbar) for chain in range(chains)], - repeat(kernel_kwargs), - ) - results = tuple(cloudpickle.loads(r) for r in results) - pool.close() - pool.join() - return results - - -def run_chains_sequential(chains, progressbar, to_run, params, random_seed, kernel_kwargs): - results = [] +def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): with Progress( - *Progress.get_default_columns(), - TextColumn("{task.comment}"), - ) as pbar: - for chain in range(chains): - results.append(to_run(*params, random_seed[chain], chain, pbar, **kernel_kwargs)) - return results - - -def _starmap_with_kwargs(pool, fn, args_iter, kwargs_iter): - # Helper function to allow kwargs with Pool.starmap - # Copied from https://stackoverflow.com/a/53173433/13311693 - args_for_starmap = zip(repeat(fn), args_iter, kwargs_iter) - return pool.starmap(_apply_args_and_kwargs, args_for_starmap) - - -def _apply_args_and_kwargs(fn, args, kwargs): - return fn(*args, **kwargs) + TextColumn("{task.description}"), + SpinnerColumn(), + TimeElapsedColumn(), + TextColumn("{task.fields[status]}"), + ) as progress: + futures = [] # keep track of the jobs + with multiprocessing.Manager() as manager: + # this is the key - we share some state between our + # main process and our worker functions + _progress = manager.dict() + + # "manually" (de)serialize params before/after multiprocessing + params = tuple(cloudpickle.dumps(p) for p in params) + kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()} + + with ProcessPoolExecutor(max_workers=cores) as executor: + for c in range(chains): # iterate over the jobs we need to run + # set visible false so we don't have a lot of bars all at once: + task_id = progress.add_task( + f"Chain {c}", status="Stage: 0 Beta: 0", visible=progressbar + ) + futures.append( + executor.submit( + _sample_smc_int, + *params, + random_seed[c], + c, + _progress, + task_id, + **kernel_kwargs, + ) + ) + + # monitor the progress: + while sum([future.done() for future in futures]) < len(futures): + for task_id, update_data in _progress.items(): + stage = update_data["stage"] + beta = update_data["beta"] + # update the progress bar for this task: + progress.update(status=f"Stage: {stage} Beta: {beta:.3f}", task_id=task_id) + + return tuple(cloudpickle.loads(r.result()) for r in futures) From 58c2b17ec7689620f45f88dfa50ef620c71d788b Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Mon, 1 Apr 2024 12:14:24 -0500 Subject: [PATCH 06/11] Fixes to MAP progress bar --- pymc/tuning/starting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index c2a7ddf2021..08364fa5ddf 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -177,7 +177,7 @@ def find_MAP( if isinstance(e, StopIteration): pm._log.info(e) finally: - cost_func.progress.update(cost_func.task, total=cost_func.n_eval) + cost_func.progress.update(cost_func.task, completed=cost_func.n_eval) print(file=sys.stdout) mx0 = RaveledVars(mx0, x0.point_map_info) @@ -238,7 +238,7 @@ def __call__(self, x): raise StopIteration self.n_eval += 1 - self.progress.advance(self.task, 1) + self.progress.update(self.task, completed=self.n_eval) if self.use_gradient: return value, grad From 41603c39668aca108b8db0456f0b49b2086350c4 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Tue, 2 Apr 2024 09:52:11 -0500 Subject: [PATCH 07/11] Customize progress bar theme --- pymc/sampling/forward.py | 10 +++++++++- pymc/sampling/mcmc.py | 18 +++++++++++------- pymc/sampling/parallel.py | 28 +++++++++++++++------------- pymc/smc/sampling.py | 12 ++++++++++-- pymc/tuning/starting.py | 12 ++++++++++-- pymc/variational/inference.py | 12 +++++++++++- 6 files changed, 66 insertions(+), 26 deletions(-) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index 910a8707f0b..f9b6b7ea4eb 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -45,7 +45,9 @@ RandomStateSharedVariable, ) from pytensor.tensor.sharedvar import SharedVariable +from rich.console import Console from rich.progress import Progress +from rich.theme import Theme from typing_extensions import TypeAlias import pymc as pm @@ -70,6 +72,12 @@ "sample_posterior_predictive", ) +custom_theme = Theme( + { + "bar.complete": "#1764f4", + "bar.finished": "green", + } +) ArrayLike: TypeAlias = Union[np.ndarray, list[float]] PointList: TypeAlias = list[PointType] @@ -830,7 +838,7 @@ def sample_posterior_predictive( _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore ppc_trace_t = _DefaultTrace(samples) try: - with Progress() as progress: + with Progress(console=Console(theme=custom_theme)) as progress: for idx in progress.track(np.arange(samples), description="Sampling ..."): if nchain > 1: # the trace object will either be a MultiTrace (and have _straces)... diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 18f724e8e33..1fb660194f5 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -35,7 +35,9 @@ from arviz import InferenceData, dict_to_dataset from arviz.data.base import make_attrs from pytensor.graph.basic import Variable +from rich.console import Console from rich.progress import Progress +from rich.theme import Theme from typing_extensions import Protocol, TypeAlias import pymc as pm @@ -80,6 +82,13 @@ Step: TypeAlias = Union[BlockedStep, CompoundStep] +custom_theme = Theme( + { + "bar.complete": "#1764f4", + "bar.finished": "green", + } +) + class SamplingIteratorCallback(Protocol): """Signature of the callable that may be passed to `pm.sample(callable=...)`.""" @@ -1026,14 +1035,9 @@ def _sample( ) _pbar_data = {"chain": chain, "divergences": 0} _desc = "Sampling chain {chain:d}, {divergences:,d} divergences" - # if progressbar: - # sampling = progress_bar(sampling_gen, total=draws, display=progressbar) - # sampling.comment = _desc.format(**_pbar_data) - # else: - # sampling = sampling_gen - with Progress() as progress: + with Progress(console=Console(theme=custom_theme)) as progress: try: - task = progress.add_task(_desc.format(**_pbar_data), total=draws) + task = progress.add_task(_desc.format(**_pbar_data), total=draws, visible=progressbar) for it, diverging in enumerate(sampling_gen): if it >= skip_first and diverging: _pbar_data["divergences"] += 1 diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index a2919a086c2..3f54f9d077f 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -26,7 +26,9 @@ import cloudpickle import numpy as np -from rich import progress +from rich.console import Console +from rich.progress import BarColumn, Progress, TimeRemainingColumn +from rich.theme import Theme from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError @@ -34,6 +36,13 @@ logger = logging.getLogger(__name__) +custom_theme = Theme( + { + "bar.complete": "#1764f4", + "bar.finished": "green", + } +) + class ParallelSamplingError(Exception): def __init__(self, message, chain): @@ -420,11 +429,12 @@ def __init__( self._in_context = False - self._progress = progress.Progress( + self._progress = Progress( "[progress.description]{task.description}", - progress.BarColumn(), + BarColumn(), "[progress.percentage]{task.percentage:>3.0f}%", - progress.TimeRemainingColumn(), + TimeRemainingColumn(), + console=Console(theme=custom_theme), ) self._show_progress = progressbar self._divergences = 0 @@ -432,9 +442,6 @@ def __init__( self._total_draws = chains * (draws + tune) self._desc = "Sampling {0._chains:d} chains, {0._divergences:,d} divergences" self._chains = chains - # if progressbar: - # self._progress = progress_bar(range(chains * (draws + tune)), display=progressbar) - # self._progress.comment = self._desc.format(self) def _make_active(self): while self._inactive and len(self._active) < self._max_active: @@ -456,12 +463,6 @@ def __iter__(self): visible=self._show_progress, ) - # if self._active and self._progress: - # self._progress.update(self._total_draws) - # progress.update( - # task, divergences=self._divergences - # ) - while self._active: draw = ProcessAdapter.recv_draw(self._active) proc, is_last, draw, tuning, stats = draw @@ -480,6 +481,7 @@ def __iter__(self): self._active.remove(proc) self._finished.append(proc) self._make_active() + progress.update(task, description=self._desc.format(self), refresh=True) # We could also yield proc.shared_point_view directly, # and only call proc.write_next() after the yield returns. diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 366b068e8a4..234a01ff355 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -25,9 +25,9 @@ import numpy as np from arviz import InferenceData - -# from fastprogress.fastprogress import force_console_behavior, progress_bar +from rich.console import Console from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn +from rich.theme import Theme import pymc @@ -39,6 +39,13 @@ from pymc.stats.convergence import log_warnings, run_convergence_checks from pymc.util import RandomState, _get_seeds_per_chain +custom_theme = Theme( + { + "bar.complete": "#1764f4", + "bar.finished": "green", + } +) + def sample_smc( draws=2000, @@ -370,6 +377,7 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): SpinnerColumn(), TimeElapsedColumn(), TextColumn("{task.fields[status]}"), + console=Console(theme=custom_theme), ) as progress: futures = [] # keep track of the jobs with multiprocessing.Manager() as manager: diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index 08364fa5ddf..a6799fe2d4f 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -27,11 +27,11 @@ import numpy as np import pytensor.gradient as tg - -# from fastprogress.fastprogress import ProgressBar, progress_bar from numpy import isfinite from pytensor import Variable +from rich.console import Console from rich.progress import Progress, TextColumn +from rich.theme import Theme from scipy.optimize import minimize import pymc as pm @@ -44,6 +44,13 @@ __all__ = ["find_MAP"] +custom_theme = Theme( + { + "bar.complete": "#1764f4", + "bar.finished": "green", + } +) + def find_MAP( start=None, @@ -214,6 +221,7 @@ def __init__(self, maxeval=5000, progressbar=True, logp_func=None, dlogp_func=No self.progress = Progress( *Progress.get_default_columns(), TextColumn("{task.fields[loss]}"), + console=Console(theme=custom_theme), ) self.task = self.progress.add_task("MAP", total=maxeval, visible=progressbar, loss="") diff --git a/pymc/variational/inference.py b/pymc/variational/inference.py index dff402a86e4..d594373e371 100644 --- a/pymc/variational/inference.py +++ b/pymc/variational/inference.py @@ -18,7 +18,9 @@ import numpy as np +from rich.console import Console from rich.progress import Progress, TextColumn, track +from rich.theme import Theme import pymc as pm @@ -41,6 +43,13 @@ State = collections.namedtuple("State", "i,step,callbacks,score") +custom_theme = Theme( + { + "bar.complete": "#1764f4", + "bar.finished": "green", + } +) + class Inference: r"""**Base class for Variational Inference** @@ -150,7 +159,7 @@ def fit(self, n=10000, score=None, callbacks=None, progressbar=True, **kwargs): def _iterate_without_loss(self, s, n, step_func, progressbar, callbacks): i = 0 try: - with Progress() as progress: + with Progress(console=Console(theme=custom_theme)) as progress: task = progress.add_task("Fitting", total=n, visible=progressbar) for i in range(n): step_func() @@ -202,6 +211,7 @@ def _infmean(input_array): with Progress( *Progress.get_default_columns(), TextColumn("{task.fields[loss]}"), + console=Console(theme=custom_theme), ) as progress: task = progress.add_task("Fitting:", total=n, visible=progressbar, loss="") for i in range(n): From c4e8fca26e4ff0af556b4870fbd34a6384a4953e Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Tue, 2 Apr 2024 13:28:24 -0500 Subject: [PATCH 08/11] Added progressbar_theme argument --- pymc/sampling/forward.py | 10 ++++++--- pymc/sampling/mcmc.py | 15 +++++++++++-- pymc/sampling/parallel.py | 5 +++-- pymc/smc/sampling.py | 10 --------- pymc/tuning/starting.py | 20 ++++++++++++----- pymc/variational/inference.py | 42 ++++++++++++++++++++++++++--------- 6 files changed, 69 insertions(+), 33 deletions(-) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index f9b6b7ea4eb..7e8558378b9 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -72,7 +72,7 @@ "sample_posterior_predictive", ) -custom_theme = Theme( +default_theme = Theme( { "bar.complete": "#1764f4", "bar.finished": "green", @@ -450,6 +450,7 @@ def sample_posterior_predictive( sample_dims: Optional[list[str]] = None, random_seed: RandomState = None, progressbar: bool = True, + progressbar_theme: Optional[Theme] = None, return_inferencedata: bool = True, extend_inferencedata: bool = False, predictions: bool = False, @@ -838,8 +839,9 @@ def sample_posterior_predictive( _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore ppc_trace_t = _DefaultTrace(samples) try: - with Progress(console=Console(theme=custom_theme)) as progress: - for idx in progress.track(np.arange(samples), description="Sampling ..."): + with Progress(console=Console(theme=progressbar_theme or default_theme)) as progress: + task = progress.add_task("Sampling ...", total=samples, visible=progressbar) + for idx in np.arange(samples): if nchain > 1: # the trace object will either be a MultiTrace (and have _straces)... if hasattr(_trace, "_straces"): @@ -859,6 +861,8 @@ def sample_posterior_predictive( for k, v in zip(vars_, values): ppc_trace_t.insert(k.name, v, idx) + progress.advance(task) + except KeyboardInterrupt: pass diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 1fb660194f5..1d683229379 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -82,7 +82,7 @@ Step: TypeAlias = Union[BlockedStep, CompoundStep] -custom_theme = Theme( +default_theme = Theme( { "bar.complete": "#1764f4", "bar.finished": "green", @@ -386,6 +386,7 @@ def sample( cores: Optional[int] = None, random_seed: RandomState = None, progressbar: bool = True, + progressbar_theme: Optional[Theme] = None, step=None, var_names: Optional[Sequence[str]] = None, nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", @@ -415,6 +416,7 @@ def sample( cores: Optional[int] = None, random_seed: RandomState = None, progressbar: bool = True, + progressbar_theme: Optional[Theme] = None, step=None, var_names: Optional[Sequence[str]] = None, nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", @@ -444,6 +446,7 @@ def sample( cores: Optional[int] = None, random_seed: RandomState = None, progressbar: bool = True, + progressbar_theme: Optional[Theme] = None, step=None, var_names: Optional[Sequence[str]] = None, nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", @@ -770,6 +773,7 @@ def sample( "tune": tune, "var_names": var_names, "progressbar": progressbar, + "progressbar_theme": progressbar_theme, "model": model, "cores": cores, "callback": callback, @@ -992,6 +996,7 @@ def _sample( trace: IBaseTrace, tune: int, model: Optional[Model] = None, + progressbar_theme: Optional[Theme] = None, callback=None, **kwargs, ) -> None: @@ -1019,6 +1024,8 @@ def _sample( tune : int Number of iterations to tune. model : Model (optional if in ``with`` context) + progressbar_theme : Theme + Optional custom theme for the progress bar. """ skip_first = kwargs.get("skip_first", 0) @@ -1035,7 +1042,7 @@ def _sample( ) _pbar_data = {"chain": chain, "divergences": 0} _desc = "Sampling chain {chain:d}, {divergences:,d} divergences" - with Progress(console=Console(theme=custom_theme)) as progress: + with Progress(console=Console(theme=progressbar_theme or default_theme)) as progress: try: task = progress.add_task(_desc.format(**_pbar_data), total=draws, visible=progressbar) for it, diverging in enumerate(sampling_gen): @@ -1137,6 +1144,7 @@ def _mp_sample( random_seed: Sequence[RandomSeed], start: Sequence[PointType], progressbar: bool = True, + progressbar_theme: Optional[Theme] = None, traces: Sequence[IBaseTrace], model: Optional[Model] = None, callback: Optional[SamplingIteratorCallback] = None, @@ -1164,6 +1172,8 @@ def _mp_sample( Dicts must contain numeric (transformed) initial values for all (transformed) free variables. progressbar : bool Whether or not to display a progress bar in the command line. + progressbar_theme : Theme + Optional custom theme for the progress bar. traces Recording backends for each chain. model : Model (optional if in ``with`` context) @@ -1188,6 +1198,7 @@ def _mp_sample( start_points=start, step_method=step, progressbar=progressbar, + progressbar_theme=progressbar_theme, mp_ctx=mp_ctx, ) try: diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 3f54f9d077f..6e3300658f0 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -36,7 +36,7 @@ logger = logging.getLogger(__name__) -custom_theme = Theme( +default_theme = Theme( { "bar.complete": "#1764f4", "bar.finished": "green", @@ -384,6 +384,7 @@ def __init__( start_points: Sequence[dict[str, np.ndarray]], step_method, progressbar: bool = True, + progressbar_theme: Theme = default_theme, mp_ctx=None, ): if any(len(arg) != chains for arg in [seeds, start_points]): @@ -434,7 +435,7 @@ def __init__( BarColumn(), "[progress.percentage]{task.percentage:>3.0f}%", TimeRemainingColumn(), - console=Console(theme=custom_theme), + console=Console(theme=progressbar_theme or default_theme), ) self._show_progress = progressbar self._divergences = 0 diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 234a01ff355..e5129e8fcee 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -25,9 +25,7 @@ import numpy as np from arviz import InferenceData -from rich.console import Console from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn -from rich.theme import Theme import pymc @@ -39,13 +37,6 @@ from pymc.stats.convergence import log_warnings, run_convergence_checks from pymc.util import RandomState, _get_seeds_per_chain -custom_theme = Theme( - { - "bar.complete": "#1764f4", - "bar.finished": "green", - } -) - def sample_smc( draws=2000, @@ -377,7 +368,6 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): SpinnerColumn(), TimeElapsedColumn(), TextColumn("{task.fields[status]}"), - console=Console(theme=custom_theme), ) as progress: futures = [] # keep track of the jobs with multiprocessing.Manager() as manager: diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index a6799fe2d4f..09f13aa0d6c 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -44,7 +44,7 @@ __all__ = ["find_MAP"] -custom_theme = Theme( +default_theme = Theme( { "bar.complete": "#1764f4", "bar.finished": "green", @@ -59,6 +59,7 @@ def find_MAP( return_raw=False, include_transformed=True, progressbar=True, + progressbar_theme=None, maxeval=5000, model=None, *args, @@ -91,6 +92,8 @@ def find_MAP( to the constrained values progressbar: bool, optional defaults to True Whether to display a progress bar in the command line. + progressbar_theme: Theme, optional + Custom theme for the progress bar. maxeval: int, optional, defaults to 5000 The maximum number of times the posterior distribution is evaluated. model: Model (optional if in `with` context) @@ -168,9 +171,9 @@ def find_MAP( method = "Powell" if compute_gradient and method != "Powell": - cost_func = CostFuncWrapper(maxeval, progressbar, logp_func, dlogp_func) + cost_func = CostFuncWrapper(maxeval, progressbar, progressbar_theme, logp_func, dlogp_func) else: - cost_func = CostFuncWrapper(maxeval, progressbar, logp_func) + cost_func = CostFuncWrapper(maxeval, progressbar, progressbar_theme, logp_func) compute_gradient = False with cost_func.progress: @@ -205,7 +208,14 @@ def allfinite(x): class CostFuncWrapper: - def __init__(self, maxeval=5000, progressbar=True, logp_func=None, dlogp_func=None): + def __init__( + self, + maxeval=5000, + progressbar=True, + progressbar_theme=None, + logp_func=None, + dlogp_func=None, + ): self.n_eval = 0 self.maxeval = maxeval self.logp_func = logp_func @@ -221,7 +231,7 @@ def __init__(self, maxeval=5000, progressbar=True, logp_func=None, dlogp_func=No self.progress = Progress( *Progress.get_default_columns(), TextColumn("{task.fields[loss]}"), - console=Console(theme=custom_theme), + console=Console(theme=progressbar_theme or default_theme), ) self.task = self.progress.add_task("MAP", total=maxeval, visible=progressbar, loss="") diff --git a/pymc/variational/inference.py b/pymc/variational/inference.py index d594373e371..4535409a993 100644 --- a/pymc/variational/inference.py +++ b/pymc/variational/inference.py @@ -43,7 +43,7 @@ State = collections.namedtuple("State", "i,step,callbacks,score") -custom_theme = Theme( +default_theme = Theme( { "bar.complete": "#1764f4", "bar.finished": "green", @@ -99,7 +99,15 @@ def run_profiling(self, n=1000, score=None, **kwargs): pass return step_func.profile - def fit(self, n=10000, score=None, callbacks=None, progressbar=True, **kwargs): + def fit( + self, + n=10000, + score=None, + callbacks=None, + progressbar=True, + progressbar_theme=None, + **kwargs, + ): """Perform Operator Variational Inference Parameters @@ -112,6 +120,8 @@ def fit(self, n=10000, score=None, callbacks=None, progressbar=True, **kwargs): calls provided functions after each iteration step progressbar : bool whether to show progressbar or not + progressbar_theme : Theme + Custom theme for the progress bar Other Parameters ---------------- @@ -146,9 +156,13 @@ def fit(self, n=10000, score=None, callbacks=None, progressbar=True, **kwargs): step_func = self.objective.step_function(score=score, **kwargs) if score: - state = self._iterate_with_loss(0, n, step_func, progressbar, callbacks) + state = self._iterate_with_loss( + 0, n, step_func, progressbar, progressbar_theme, callbacks + ) else: - state = self._iterate_without_loss(0, n, step_func, progressbar, callbacks) + state = self._iterate_without_loss( + 0, n, step_func, progressbar, progressbar_theme, callbacks + ) # hack to allow pm.fit() access to loss hist self.approx.hist = self.hist @@ -156,10 +170,10 @@ def fit(self, n=10000, score=None, callbacks=None, progressbar=True, **kwargs): return self.approx - def _iterate_without_loss(self, s, n, step_func, progressbar, callbacks): + def _iterate_without_loss(self, s, n, step_func, progressbar, progressbar_theme, callbacks): i = 0 try: - with Progress(console=Console(theme=custom_theme)) as progress: + with Progress(console=Console(theme=progressbar_theme or default_theme)) as progress: task = progress.add_task("Fitting", total=n, visible=progressbar) for i in range(n): step_func() @@ -195,7 +209,7 @@ def _iterate_without_loss(self, s, n, step_func, progressbar, callbacks): logger.info(str(e)) return State(i + s, step=step_func, callbacks=callbacks, score=False) - def _iterate_with_loss(self, s, n, step_func, progressbar, callbacks): + def _iterate_with_loss(self, s, n, step_func, progressbar, progressbar_theme, callbacks): def _infmean(input_array): """Return the mean of the finite values of the array""" input_array = input_array[np.isfinite(input_array)].astype("float64") @@ -211,7 +225,7 @@ def _infmean(input_array): with Progress( *Progress.get_default_columns(), TextColumn("{task.fields[loss]}"), - console=Console(theme=custom_theme), + console=Console(theme=progressbar_theme or default_theme), ) as progress: task = progress.add_task("Fitting:", total=n, visible=progressbar, loss="") for i in range(n): @@ -274,15 +288,17 @@ def _infmean(input_array): self.hist = np.concatenate([self.hist, scores]) return State(i + s, step=step_func, callbacks=callbacks, score=True) - def refine(self, n, progressbar=True): + def refine(self, n, progressbar=True, progressbar_theme=None): """Refine the solution using the last compiled step function""" if self.state is None: raise TypeError("Need to call `.fit` first") i, step, callbacks, score = self.state if score: - state = self._iterate_with_loss(i, n, step, progressbar, callbacks) + state = self._iterate_with_loss(i, n, step, progressbar, progressbar_theme, callbacks) else: - state = self._iterate_without_loss(i, n, step, progressbar, callbacks) + state = self._iterate_without_loss( + i, n, step, progressbar, progressbar_theme, callbacks + ) self.state = state @@ -639,6 +655,7 @@ def fit( score=None, callbacks=None, progressbar=True, + progressbar_theme=None, obj_n_mc=500, **kwargs, ): @@ -647,6 +664,7 @@ def fit( score=score, callbacks=callbacks, progressbar=progressbar, + progressbar_theme=progressbar_theme, obj_n_mc=obj_n_mc, **kwargs, ) @@ -697,6 +715,8 @@ def fit( calls provided functions after each iteration step progressbar: bool whether to show progressbar or not + progressbar_theme: Theme + Custom theme for the progress bar obj_n_mc: `int` Number of monte carlo samples used for approximation of objective gradients tf_n_mc: `int` From eefa29d5fd2aae922743b0c63485bb06835ded13 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Wed, 3 Apr 2024 07:55:21 -0500 Subject: [PATCH 09/11] Moved default progressbar theme to util --- pymc/sampling/forward.py | 12 +++--------- pymc/sampling/mcmc.py | 20 +++++++------------- pymc/sampling/parallel.py | 13 +++---------- pymc/tuning/starting.py | 16 ++++------------ pymc/util.py | 8 ++++++++ pymc/variational/inference.py | 19 ++++++------------- 6 files changed, 31 insertions(+), 57 deletions(-) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index 7e8558378b9..d43c587a473 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -61,6 +61,7 @@ RandomState, _get_seeds_per_chain, dataset_to_point_list, + default_progress_theme, get_default_varnames, point_wrapper, ) @@ -72,13 +73,6 @@ "sample_posterior_predictive", ) -default_theme = Theme( - { - "bar.complete": "#1764f4", - "bar.finished": "green", - } -) - ArrayLike: TypeAlias = Union[np.ndarray, list[float]] PointList: TypeAlias = list[PointType] @@ -450,7 +444,7 @@ def sample_posterior_predictive( sample_dims: Optional[list[str]] = None, random_seed: RandomState = None, progressbar: bool = True, - progressbar_theme: Optional[Theme] = None, + progressbar_theme: Optional[Theme] = default_progress_theme, return_inferencedata: bool = True, extend_inferencedata: bool = False, predictions: bool = False, @@ -839,7 +833,7 @@ def sample_posterior_predictive( _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore ppc_trace_t = _DefaultTrace(samples) try: - with Progress(console=Console(theme=progressbar_theme or default_theme)) as progress: + with Progress(console=Console(theme=progressbar_theme)) as progress: task = progress.add_task("Sampling ...", total=samples, visible=progressbar) for idx in np.arange(samples): if nchain > 1: diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 1d683229379..1241b10b865 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -67,6 +67,7 @@ RandomSeed, RandomState, _get_seeds_per_chain, + default_progress_theme, drop_warning_stat, get_untransformed_name, is_transformed_name, @@ -82,13 +83,6 @@ Step: TypeAlias = Union[BlockedStep, CompoundStep] -default_theme = Theme( - { - "bar.complete": "#1764f4", - "bar.finished": "green", - } -) - class SamplingIteratorCallback(Protocol): """Signature of the callable that may be passed to `pm.sample(callable=...)`.""" @@ -386,7 +380,7 @@ def sample( cores: Optional[int] = None, random_seed: RandomState = None, progressbar: bool = True, - progressbar_theme: Optional[Theme] = None, + progressbar_theme: Optional[Theme] = default_progress_theme, step=None, var_names: Optional[Sequence[str]] = None, nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", @@ -416,7 +410,7 @@ def sample( cores: Optional[int] = None, random_seed: RandomState = None, progressbar: bool = True, - progressbar_theme: Optional[Theme] = None, + progressbar_theme: Optional[Theme] = default_progress_theme, step=None, var_names: Optional[Sequence[str]] = None, nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", @@ -446,7 +440,7 @@ def sample( cores: Optional[int] = None, random_seed: RandomState = None, progressbar: bool = True, - progressbar_theme: Optional[Theme] = None, + progressbar_theme: Optional[Theme] = default_progress_theme, step=None, var_names: Optional[Sequence[str]] = None, nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", @@ -996,7 +990,7 @@ def _sample( trace: IBaseTrace, tune: int, model: Optional[Model] = None, - progressbar_theme: Optional[Theme] = None, + progressbar_theme: Optional[Theme] = default_progress_theme, callback=None, **kwargs, ) -> None: @@ -1042,7 +1036,7 @@ def _sample( ) _pbar_data = {"chain": chain, "divergences": 0} _desc = "Sampling chain {chain:d}, {divergences:,d} divergences" - with Progress(console=Console(theme=progressbar_theme or default_theme)) as progress: + with Progress(console=Console(theme=progressbar_theme)) as progress: try: task = progress.add_task(_desc.format(**_pbar_data), total=draws, visible=progressbar) for it, diverging in enumerate(sampling_gen): @@ -1144,7 +1138,7 @@ def _mp_sample( random_seed: Sequence[RandomSeed], start: Sequence[PointType], progressbar: bool = True, - progressbar_theme: Optional[Theme] = None, + progressbar_theme: Optional[Theme] = default_progress_theme, traces: Sequence[IBaseTrace], model: Optional[Model] = None, callback: Optional[SamplingIteratorCallback] = None, diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 6e3300658f0..caf865c78eb 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -32,17 +32,10 @@ from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError -from pymc.util import RandomSeed +from pymc.util import RandomSeed, default_progress_theme logger = logging.getLogger(__name__) -default_theme = Theme( - { - "bar.complete": "#1764f4", - "bar.finished": "green", - } -) - class ParallelSamplingError(Exception): def __init__(self, message, chain): @@ -384,7 +377,7 @@ def __init__( start_points: Sequence[dict[str, np.ndarray]], step_method, progressbar: bool = True, - progressbar_theme: Theme = default_theme, + progressbar_theme: Theme = default_progress_theme, mp_ctx=None, ): if any(len(arg) != chains for arg in [seeds, start_points]): @@ -435,7 +428,7 @@ def __init__( BarColumn(), "[progress.percentage]{task.percentage:>3.0f}%", TimeRemainingColumn(), - console=Console(theme=progressbar_theme or default_theme), + console=Console(theme=progressbar_theme), ) self._show_progress = progressbar self._divergences = 0 diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index 09f13aa0d6c..90f56d19bee 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -31,7 +31,6 @@ from pytensor import Variable from rich.console import Console from rich.progress import Progress, TextColumn -from rich.theme import Theme from scipy.optimize import minimize import pymc as pm @@ -39,18 +38,11 @@ from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.initial_point import make_initial_point_fn from pymc.model import modelcontext -from pymc.util import get_default_varnames, get_value_vars_from_user_vars +from pymc.util import default_progress_theme, get_default_varnames, get_value_vars_from_user_vars from pymc.vartypes import discrete_types, typefilter __all__ = ["find_MAP"] -default_theme = Theme( - { - "bar.complete": "#1764f4", - "bar.finished": "green", - } -) - def find_MAP( start=None, @@ -59,7 +51,7 @@ def find_MAP( return_raw=False, include_transformed=True, progressbar=True, - progressbar_theme=None, + progressbar_theme=default_progress_theme, maxeval=5000, model=None, *args, @@ -212,7 +204,7 @@ def __init__( self, maxeval=5000, progressbar=True, - progressbar_theme=None, + progressbar_theme=default_progress_theme, logp_func=None, dlogp_func=None, ): @@ -231,7 +223,7 @@ def __init__( self.progress = Progress( *Progress.get_default_columns(), TextColumn("{task.fields[loss]}"), - console=Console(theme=progressbar_theme or default_theme), + console=Console(theme=progressbar_theme), ) self.task = self.progress.add_task("MAP", total=maxeval, visible=progressbar, loss="") diff --git a/pymc/util.py b/pymc/util.py index 8388a8ed49b..b72e17e0ae5 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -27,11 +27,19 @@ from pytensor import Variable from pytensor.compile import SharedVariable from pytensor.graph.utils import ValidatingScratchpad +from rich.theme import Theme from pymc.exceptions import BlockModelAccessError VarName = NewType("VarName", str) +default_progress_theme = Theme( + { + "bar.complete": "#1764f4", + "bar.finished": "green", + } +) + class _UnsetType: """Type for the `UNSET` object to make it look nice in `help(...)` outputs.""" diff --git a/pymc/variational/inference.py b/pymc/variational/inference.py index 4535409a993..3d9e6fd8eae 100644 --- a/pymc/variational/inference.py +++ b/pymc/variational/inference.py @@ -20,10 +20,10 @@ from rich.console import Console from rich.progress import Progress, TextColumn, track -from rich.theme import Theme import pymc as pm +from pymc.util import default_progress_theme from pymc.variational import test_functions from pymc.variational.approximations import Empirical, FullRank, MeanField from pymc.variational.operators import KL, KSD @@ -43,13 +43,6 @@ State = collections.namedtuple("State", "i,step,callbacks,score") -default_theme = Theme( - { - "bar.complete": "#1764f4", - "bar.finished": "green", - } -) - class Inference: r"""**Base class for Variational Inference** @@ -105,7 +98,7 @@ def fit( score=None, callbacks=None, progressbar=True, - progressbar_theme=None, + progressbar_theme=default_progress_theme, **kwargs, ): """Perform Operator Variational Inference @@ -173,7 +166,7 @@ def fit( def _iterate_without_loss(self, s, n, step_func, progressbar, progressbar_theme, callbacks): i = 0 try: - with Progress(console=Console(theme=progressbar_theme or default_theme)) as progress: + with Progress(console=Console(theme=progressbar_theme)) as progress: task = progress.add_task("Fitting", total=n, visible=progressbar) for i in range(n): step_func() @@ -225,7 +218,7 @@ def _infmean(input_array): with Progress( *Progress.get_default_columns(), TextColumn("{task.fields[loss]}"), - console=Console(theme=progressbar_theme or default_theme), + console=Console(theme=progressbar_theme), ) as progress: task = progress.add_task("Fitting:", total=n, visible=progressbar, loss="") for i in range(n): @@ -288,7 +281,7 @@ def _infmean(input_array): self.hist = np.concatenate([self.hist, scores]) return State(i + s, step=step_func, callbacks=callbacks, score=True) - def refine(self, n, progressbar=True, progressbar_theme=None): + def refine(self, n, progressbar=True, progressbar_theme=default_progress_theme): """Refine the solution using the last compiled step function""" if self.state is None: raise TypeError("Need to call `.fit` first") @@ -655,7 +648,7 @@ def fit( score=None, callbacks=None, progressbar=True, - progressbar_theme=None, + progressbar_theme=default_progress_theme, obj_n_mc=500, **kwargs, ): From 3f107e94f619339f42e36e2b34bf6f13ef038411 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Wed, 3 Apr 2024 08:08:11 -0500 Subject: [PATCH 10/11] Convert compute_log_density to use Progress instead of track --- pymc/stats/log_density.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pymc/stats/log_density.py b/pymc/stats/log_density.py index 1435560437b..daf172342f4 100644 --- a/pymc/stats/log_density.py +++ b/pymc/stats/log_density.py @@ -15,14 +15,15 @@ from typing import Optional, cast from arviz import InferenceData, dict_to_dataset -from rich.progress import track +from rich.console import Console +from rich.progress import Progress import pymc from pymc.backends.arviz import _DefaultTrace, coords_and_dims_for_inferencedata from pymc.model import Model, modelcontext from pymc.pytensorf import PointFunc -from pymc.util import dataset_to_point_list +from pymc.util import dataset_to_point_list, default_progress_theme __all__ = ("compute_log_likelihood", "compute_log_prior") @@ -169,15 +170,14 @@ def compute_log_density( n_pts = len(posterior_pts) logdens_dict = _DefaultTrace(n_pts) - if progressbar: - indices = track(range(n_pts), description="Computing log density") - else: - indices = range(n_pts) - for idx in indices: - logdenss_pts = elemwise_logdens_fn(posterior_pts[idx]) - for rv_name, rv_logdens in zip(var_names, logdenss_pts): - logdens_dict.insert(rv_name, rv_logdens, idx) + with Progress(console=Console(theme=default_progress_theme)) as progress: + task = progress.add_task("Computing log density...", total=n_pts, visible=progressbar) + for idx in range(n_pts): + logdenss_pts = elemwise_logdens_fn(posterior_pts[idx]) + for rv_name, rv_logdens in zip(var_names, logdenss_pts): + logdens_dict.insert(rv_name, rv_logdens, idx) + progress.update(task, advance=1) logdens_trace = logdens_dict.trace_dict for key, array in logdens_trace.items(): From f490b796a83f3a57300c69b8cf66af5c2bd57665 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Wed, 3 Apr 2024 12:02:11 -0500 Subject: [PATCH 11/11] Getting rid of mypy complaint --- pymc/sampling/parallel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index caf865c78eb..29ccc1a0d0e 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -22,6 +22,7 @@ from collections import namedtuple from collections.abc import Sequence +from typing import Optional import cloudpickle import numpy as np @@ -377,7 +378,7 @@ def __init__( start_points: Sequence[dict[str, np.ndarray]], step_method, progressbar: bool = True, - progressbar_theme: Theme = default_progress_theme, + progressbar_theme: Optional[Theme] = default_progress_theme, mp_ctx=None, ): if any(len(arg) != chains for arg in [seeds, start_points]):