Skip to content

Commit

Permalink
Refactor progress bars (#1272)
Browse files Browse the repository at this point in the history
tqdm progress bars are used in a couple of places. They don't play so well in
non-interactive jobs, testing, ... . They also cause trouble with nbspinx
(#1246, #1271).
Progress bars can be disabled for specific tasks, but not globally (or at
least not very conveniently).
Since recently, tqdm can be controlled via environment variables (e.g.,
disabling all progress bars or changing update frequency). However,
this works by changing the argument defaults, so it only works if we don't
pass explicit `disable=...`. Therefore, this PR introduces some wrapper that
checks whether the user explicitly enabled/disabled progress bars. If not,
we go with the tqdm default, which means showing all progress bars unless
globally disabled. An additional `enabled` argument is added for convenience.

---------

Co-authored-by: Dilan Pathirana <[email protected]>
  • Loading branch information
dweindl and dilpath authored Jan 9, 2024
1 parent 6366ecf commit e389257
Show file tree
Hide file tree
Showing 13 changed files with 74 additions and 30 deletions.
5 changes: 5 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
# Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1"

# TQDM and nbsphinx do not play well together. Therefore, disable TQDM
# for the documentation build.
# (`Content block expected for the "raw" directive; none found.`)
os.environ["TQDM_DISABLE"] = "1"

# -- General configuration ------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
Expand Down
4 changes: 2 additions & 2 deletions pypesto/engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self):

@abc.abstractmethod
def execute(
self, tasks: list[Task], progress_bar: bool = True
self, tasks: list[Task], progress_bar: bool = None
) -> list[Any]:
"""Execute tasks.
Expand All @@ -22,6 +22,6 @@ def execute(
tasks:
List of tasks to execute.
progress_bar:
Whether to display a progress bar. Defaults to ``True``.
Whether to display a progress bar.
"""
raise NotImplementedError("This engine is not intended to be called.")
8 changes: 4 additions & 4 deletions pypesto/engine/mpi_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import cloudpickle as pickle
from mpi4py import MPI
from mpi4py.futures import MPIPoolExecutor
from tqdm import tqdm

from ..util import tqdm
from .base import Engine
from .task import Task

Expand All @@ -32,7 +32,7 @@ def __init__(self):
super().__init__()

def execute(
self, tasks: list[Task], progress_bar: bool = True
self, tasks: list[Task], progress_bar: bool = None
) -> list[Any]:
"""
Pickle tasks and distribute work to workers.
Expand All @@ -42,7 +42,7 @@ def execute(
tasks:
List of :class:`pypesto.engine.Task` to execute.
progress_bar:
Whether to display a progress bar. Defaults to ``True``.
Whether to display a progress bar.
Returns
-------
Expand All @@ -55,6 +55,6 @@ def execute(

with MPIPoolExecutor() as executor:
results = executor.map(
work, tqdm(pickled_tasks, disable=not progress_bar)
work, tqdm(pickled_tasks, enable=progress_bar)
)
return results
8 changes: 4 additions & 4 deletions pypesto/engine/multi_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from typing import Any, Union

import cloudpickle as pickle
from tqdm import tqdm

from ..util import tqdm
from .base import Engine
from .task import Task

Expand Down Expand Up @@ -52,7 +52,7 @@ def __init__(
self.method: str = method

def execute(
self, tasks: list[Task], progress_bar: bool = True
self, tasks: list[Task], progress_bar: bool = None
) -> list[Any]:
"""Pickle tasks and distribute work over parallel processes.
Expand All @@ -61,7 +61,7 @@ def execute(
tasks:
List of :class:`pypesto.engine.Task` to execute.
progress_bar:
Whether to display a progress bar. Defaults to ``True``.
Whether to display a progress bar.
Returns
-------
Expand All @@ -81,7 +81,7 @@ def execute(
tqdm(
pool.imap(work, pickled_tasks),
total=len(pickled_tasks),
disable=not progress_bar,
enable=progress_bar,
),
)

Expand Down
7 changes: 3 additions & 4 deletions pypesto/engine/multi_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Union

from tqdm import tqdm

from ..util import tqdm
from .base import Engine
from .task import Task

Expand Down Expand Up @@ -43,7 +42,7 @@ def __init__(self, n_threads: Union[int, None] = None):
self.n_threads: int = n_threads

def execute(
self, tasks: list[Task], progress_bar: bool = True
self, tasks: list[Task], progress_bar: bool = None
) -> list[Any]:
"""Deepcopy tasks and distribute work over parallel threads.
Expand All @@ -70,7 +69,7 @@ def execute(
tqdm(
pool.map(work, copied_tasks),
total=len(copied_tasks),
disable=not progress_bar,
enable=progress_bar,
),
)

Expand Down
10 changes: 6 additions & 4 deletions pypesto/engine/single_core.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
"""Engines without parallelization."""
from typing import Any

from tqdm import tqdm

from ..util import tqdm
from .base import Engine
from .task import Task

Expand All @@ -18,7 +17,7 @@ def __init__(self):
super().__init__()

def execute(
self, tasks: list[Task], progress_bar: bool = True
self, tasks: list[Task], progress_bar: bool = None
) -> list[Any]:
"""Execute all tasks in a simple for loop sequentially.
Expand All @@ -34,7 +33,10 @@ def execute(
A list of results.
"""
results = []
for task in tqdm(tasks, disable=not progress_bar):
for task in tqdm(
tasks,
enable=progress_bar,
):
results.append(task.execute())

return results
2 changes: 1 addition & 1 deletion pypesto/ensemble/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ def predict(
include_llh_weights: bool = False,
include_sigmay: bool = False,
engine: Engine = None,
progress_bar: bool = True,
progress_bar: bool = None,
) -> EnsemblePrediction:
"""
Run predictions for a full ensemble.
Expand Down
2 changes: 1 addition & 1 deletion pypesto/optimize/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def minimize(
startpoint_method: Union[StartpointMethod, Callable, bool] = None,
result: Result = None,
engine: Engine = None,
progress_bar: bool = True,
progress_bar: bool = None,
options: OptimizeOptions = None,
history_options: HistoryOptions = None,
filename: Union[str, Callable, None] = None,
Expand Down
2 changes: 1 addition & 1 deletion pypesto/profile/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def parameter_profile(
result_index: int = 0,
next_guess_method: Union[Callable, str] = 'adaptive_step_regression',
profile_options: ProfileOptions = None,
progress_bar: bool = True,
progress_bar: bool = None,
filename: Union[str, Callable, None] = None,
overwrite: bool = False,
) -> Result:
Expand Down
2 changes: 1 addition & 1 deletion pypesto/sample/adaptive_metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def default_options(cls):
# target acceptance rate
'target_acceptance_rate': 0.234,
# show progress
'show_progress': True,
'show_progress': None,
}

def initialize(self, problem: Problem, x0: np.ndarray):
Expand Down
8 changes: 4 additions & 4 deletions pypesto/sample/metropolis.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Dict, Sequence, Union

import numpy as np
from tqdm import tqdm

from ..history import NoHistory
from ..objective import NegLogPriors, ObjectiveBase
from ..problem import Problem
from ..result import McmcPtResult
from ..util import tqdm
from .sampler import InternalSample, InternalSampler


Expand Down Expand Up @@ -51,7 +51,7 @@ def default_options(cls):
"""Return the default options for the sampler."""
return {
'std': 1.0, # the proposal standard deviation
'show_progress': True, # whether to show the progress
'show_progress': None, # whether to show the progress
}

def initialize(self, problem: Problem, x0: np.ndarray):
Expand All @@ -73,10 +73,10 @@ def sample(self, n_samples: int, beta: float = 1.0):
lpost = -self.trace_neglogpost[-1]
lprior = -self.trace_neglogprior[-1]

show_progress = self.options['show_progress']
show_progress = self.options.get('show_progress', None)

# loop over iterations
for _ in tqdm(range(int(n_samples)), disable=not show_progress):
for _ in tqdm(range(int(n_samples)), enable=show_progress):
# perform step
x, lpost, lprior = self._perform_step(
x=x, lpost=lpost, lprior=lprior, beta=beta
Expand Down
8 changes: 4 additions & 4 deletions pypesto/sample/parallel_tempering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from typing import Dict, List, Sequence, Union

import numpy as np
from tqdm import tqdm

from ..problem import Problem
from ..result import McmcPtResult
from ..util import tqdm
from .sampler import InternalSampler, Sampler


Expand Down Expand Up @@ -70,7 +70,7 @@ def default_options(cls) -> Dict:
'max_temp': 5e4,
'exponent': 4,
'temper_log_posterior': False,
'show_progress': True,
'show_progress': None,
}

def initialize(
Expand All @@ -89,9 +89,9 @@ def initialize(

def sample(self, n_samples: int, beta: float = 1.0):
"""Sample and swap in between samplers."""
show_progress = self.options['show_progress']
show_progress = self.options.get('show_progress', None)
# loop over iterations
for i_sample in tqdm(range(int(n_samples)), disable=not show_progress):
for i_sample in tqdm(range(int(n_samples)), enable=show_progress):
# TODO test
# sample
for sampler, beta in zip(self.samplers, self.betas):
Expand Down
38 changes: 38 additions & 0 deletions pypesto/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import numpy as np
from scipy import cluster
from tqdm import tqdm as _tqdm


def _check_none(fun: Callable[..., Any]) -> Callable[..., Union[Any, None]]:
Expand Down Expand Up @@ -295,3 +296,40 @@ def delete_nan_inf(
)
)
return x, fvals[finite_fvals]


def tqdm(*args, enable: bool = None, **kwargs):
"""
Create a progress bar using tqdm.
Parameters
----------
args:
Arguments passed to tqdm.
enable:
Whether to enable the progress bar.
If None, use tqdm defaults.
Mutually exclusive with `disable`.
kwargs:
Keyword arguments passed to tqdm.
Returns
-------
progress_bar:
A progress bar.
"""
# Drop the `disable` argument unless it is not-None.
# This way, we don't interfere with TQDM_DISABLE or other global
# tqdm settings.
disable = kwargs.pop("disable", None)

if enable is not None:
if disable is not None and enable != disable:
raise ValueError(
"Contradicting values for `enable` and `disable` passed."
)
disable = not enable

if disable is not None:
kwargs["disable"] = disable
return _tqdm(*args, **kwargs)

0 comments on commit e389257

Please sign in to comment.