diff --git a/zeus/ensemble.py b/zeus/ensemble.py index 4242d87..1fd258e 100644 --- a/zeus/ensemble.py +++ b/zeus/ensemble.py @@ -62,19 +62,25 @@ def __init__(self, check_walkers=True, shuffle_ensemble=True, light_mode=False, + logger=None, ): # Set up logger - self.logger = logging.getLogger() - for handler in self.logger.handlers[:]: - self.logger.removeHandler(handler) - handler = logging.StreamHandler() - self.logger.addHandler(handler) - if verbose: - self.logger.setLevel(logging.INFO) + if logger is None: + self.logger = logging.getLogger() + for handler in self.logger.handlers[:]: + self.logger.removeHandler(handler) + handler = logging.StreamHandler() + self.logger.addHandler(handler) + if verbose: + self.logger.setLevel(logging.INFO) + else: + self.logger.setLevel(logging.WARNING) + elif isinstance(logger, logging.Logger): + self.logger = logger else: - self.logger.setLevel(logging.WARNING) - + raise ValueError("logger should be an instance of logging.Logger or None") + # Parse the move schedule if moves is None: self._moves = [DifferentialMove()] @@ -724,4 +730,4 @@ def sample(self, class sampler(EnsembleSampler): def __init__(self, *args, **kwargs): logging.warning('The sampler class has been deprecated. Please use the new EnsembleSampler class.') - super().__init__(*args, **kwargs) \ No newline at end of file + super().__init__(*args, **kwargs)