diff --git a/appletree/context.py b/appletree/context.py index 6a52bec..fbcf221 100644 --- a/appletree/context.py +++ b/appletree/context.py @@ -272,17 +272,36 @@ def continue_fitting(self, context=None, iteration=500, batch_size=1_000_000): self._dump_meta(batch_size=batch_size) return result - def get_post_parameters(self): - """Get parameters correspondes to max posterior.""" - logp = self.sampler.get_log_prob(flat=True) - chain = self.sampler.get_chain(flat=True) - mpe_parameters = chain[np.argmax(logp)] - mpe_parameters = emcee.ensemble.ndarray_to_list_of_dicts( - [mpe_parameters], + def get_post_parameters(self, which="mpe"): + """Get parameters from the backend. + + Args: + which: str, 'mpe', 'random' or 'median'. 'mpe' is the maximum posterior estimate, + i.e. the parameter set with the highest posterior value. 'random' returns a + random parameter set from the posterior distribution. 'median' is the marginal medians. + + """ + # Assign attributes for the first time + # This speeds up if the user wanna call this function many times + if not hasattr(self, "_logp"): + self._logp = self.sampler.get_log_prob(flat=True) + if not hasattr(self, "_chain"): + self._chain = self.sampler.get_chain(flat=True) + if which == "mpe": + _parameters = self._chain[np.argmax(self._logp)] + elif which == "random": + _parameters = self._chain[np.random.randint(len(self._logp))] + elif which == "median": + _parameters = np.median(self._chain, axis=0) + else: + raise ValueError(f"which should be 'mpe', 'random' or 'median', got {which}!") + + _parameters = emcee.ensemble.ndarray_to_list_of_dicts( + [_parameters], self.sampler.parameter_names, )[0] parameters = copy.deepcopy(self.par_manager.get_all_parameter()) - parameters.update(mpe_parameters) + parameters.update(_parameters) return parameters def get_all_post_parameters(self, **kwargs):