diff --git a/docs/source/emulator_analysis/index.rst b/docs/source/emulator_analysis/index.rst index 0c36324..a93e7b4 100644 --- a/docs/source/emulator_analysis/index.rst +++ b/docs/source/emulator_analysis/index.rst @@ -153,6 +153,31 @@ This is implemented into the SWIFT-Emulator with sweep as `ModelValues` and `ModelParameters` containers, that are easy to parse. +Interactive plots +----------------- + +Another way to explore the effect of varying the parameters is +to try an interactive plot. Every emulator object contains an +`interactive_plot` method. This generates a plot with a slider +for each parameter. The plot will update to show the emulator +predictions when sliders are adjusted. The emulator will make +its initial prediction using the parameter values passed to it. +If no parameters are passed if will default to the midpoint of +each parameter range. It is also possible to pass reference data +to overplot on the emulator predictions. If no reference data is +passed the plot will display a fixed dashed line corresponding to +the prediction using the initial parameter values. + +.. code-block:: python + + schecter_emulator.interactive_plot(predict_x, initial_params=center, + xlabel="Stellar mass", ylabel="dn/dlogM", + x_data=[10.5, 11, 11.5], + y_data=[-10, -11, -12]) + +.. image:: interactive_plot.png + + Model Parameters Features ------------------------- @@ -234,4 +259,4 @@ This method is a lot slower than the default hyperparameter optimisation, and may take some time to compute. The main take away from plots like this is to see whether the hyperparameters are converged, and whether they are -consistent with the faster optimisation method. \ No newline at end of file +consistent with the faster optimisation method. diff --git a/docs/source/emulator_analysis/interactive_plot.png b/docs/source/emulator_analysis/interactive_plot.png new file mode 100644 index 0000000..5e20c48 Binary files /dev/null and b/docs/source/emulator_analysis/interactive_plot.png differ diff --git a/swiftemulator/emulators/base.py b/swiftemulator/emulators/base.py index df2cbbc..276f242 100644 --- a/swiftemulator/emulators/base.py +++ b/swiftemulator/emulators/base.py @@ -85,7 +85,7 @@ def predict_values( Parameters ---------- - independent, np.array + independent: np.array Independent continuous variables to evaluate the emulator at. If the emulator is discrete, these are only allowed to be the discrete independent variables that the emulator was trained at @@ -98,12 +98,12 @@ def predict_values( Returns ------- - dependent_predictions, np.array + dependent_predictions: np.array Array of predictions, if the emulator is a function f, these are the predicted values of f(independent) evaluted at the position of the input ``model_parameters``. - dependent_prediction_errors, np.array + dependent_prediction_errors: np.array Errors on the model predictions. For models where the errors are unconstrained, this is an array of zeroes. @@ -123,3 +123,98 @@ def predict_values( ) raise NotImplementedError + + def interactive_plot( + self, + x: np.array, + initial_params: Dict[str, float] = {}, + xlabel: str = "", + ylabel: str = "", + x_data: np.array = None, + y_data: np.array = None, + ): + """ + Generates an interactive plot which displays the emulator predictions. + If initial_params should contain the initial parameter values to make a + prediction for. If initial_params is not passed the midpoint of each of + the parameter values will be used instead. If no reference data is + passed to be overplotted then the plot will display a line which + corresponds to the predictions for the initial parameter values. + + Parameters + ---------- + + x: np.array + Array of data for which the emulator should make predictions. + + initial_params: Dict[str, float], optional + What parameters values to plot the predicition for initally. + If missing the midpoint of each parameter range will be used. + + xlabel: str, optional + Label for horizontal axis on the resultant figure. + + ylabel: str, optional + Label for vertical axis on the resultant figure. + + x_data: np.array, optional + Array containing x-values of reference data to plot. + + y_data: np.array, optional + Array containing y-values of reference data to plot. + Must be the same shape as x_data + """ + import matplotlib.pyplot as plt + from matplotlib.widgets import Slider + + fig, ax = plt.subplots() + model_specification = self.model_specification + sliders = [] + n_param = model_specification.number_of_parameters + fig.subplots_adjust(bottom=0.12 + n_param * 0.1) + for i in range(n_param): + # Extracting information needed for slider + name = model_specification.parameter_names[i] + lo_lim = sorted(model_specification.parameter_limits[i])[0] + hi_lim = sorted(model_specification.parameter_limits[i])[1] + if not name in initial_params: + initial_params[name] = (lo_lim + hi_lim) / 2 + + # Adding slider + printable_name = name + if model_specification.parameter_printable_names: + printable_name = model_specification.parameter_printable_names[i] + slider_ax = fig.add_axes([0.35, i * 0.1, 0.3, 0.1]) + slider = Slider( + ax=slider_ax, + label=printable_name, + valmin=lo_lim, + valmax=hi_lim, + valinit=initial_params[name], + ) + sliders.append(slider) + + # Plotting lines and reference data + pred, pred_var = self.predict_values(x, initial_params) + if (x_data is None) or (y_data is None): + ax.plot(x, pred, "k--") + else: + ax.plot(x_data, y_data, "k.") + (line,) = ax.plot(x, pred) + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + + # Define and enable update function + def update(val): + params = { + model_specification.parameter_names[i]: sliders[i].val + for i in range(n_param) + } + pred, pred_var = self.predict_values(x, params) + line.set_ydata(pred) + + for slider in sliders: + slider.on_changed(update) + + plt.show() + plt.close() diff --git a/swiftemulator/emulators/multi_gaussian_process.py b/swiftemulator/emulators/multi_gaussian_process.py index 2b5f45e..3802a4f 100644 --- a/swiftemulator/emulators/multi_gaussian_process.py +++ b/swiftemulator/emulators/multi_gaussian_process.py @@ -231,12 +231,16 @@ def predict_values( for index, (low, high) in enumerate(self.independent_regions): mask = np.logical_and( - independent > low - if low is not None - else np.ones_like(independent).astype(bool), - independent < high - if high is not None - else np.ones_like(independent).astype(bool), + ( + independent > low + if low is not None + else np.ones_like(independent).astype(bool) + ), + ( + independent < high + if high is not None + else np.ones_like(independent).astype(bool) + ), ) predicted, errors = self.emulators[index].predict_values( diff --git a/swiftemulator/io/swift.py b/swiftemulator/io/swift.py index e9b3247..8429525 100644 --- a/swiftemulator/io/swift.py +++ b/swiftemulator/io/swift.py @@ -103,8 +103,8 @@ def load_pipeline_outputs( "adaptive_mass_function", "histogram", ] - recursive_search = ( - lambda d, k: d.get(k[0], recursive_search(d, k[1:])) if len(k) > 0 else None + recursive_search = lambda d, k: ( + d.get(k[0], recursive_search(d, k[1:])) if len(k) > 0 else None ) line_search = lambda d: recursive_search(d, line_types)