Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interactive plot #46

Merged
merged 5 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion docs/source/emulator_analysis/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,25 @@ This is implemented into the SWIFT-Emulator with
sweep as `ModelValues` and `ModelParameters`
containers, that are easy to parse.

Interactive plots
-----------------

Another way to explore the effect of varying the parameters is
to try an interactive plot. Every emulator object contains an
`interactive_plot` method. This generates a plot with a slider
for each parameter. The plot will update to show the emulator
predictions when sliders are adjusted.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a line that explains the default reference line.


.. code-block:: python

schecter_emulator.interactive_plot(predict_x, xlabel="Stellar mass", ylabel="dn/dlogM")

.. image:: interactive_plot.png

It is possible to pass reference data to be plotted when calling
:meth:`swiftemulator.emulators.base.BaseEmulator.interactive\_plot`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you could add this to the example? Similar to the previous comment I would at least mention the reference line and/or the extra datapoints



Model Parameters Features
-------------------------

Expand Down Expand Up @@ -234,4 +253,4 @@ This method is a lot slower than the default hyperparameter
optimisation, and may take some time to compute. The main
take away from plots like this is to see whether the
hyperparameters are converged, and whether they are
consistent with the faster optimisation method.
consistent with the faster optimisation method.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
93 changes: 90 additions & 3 deletions swiftemulator/emulators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def predict_values(
Parameters
----------

independent, np.array
independent: np.array
Independent continuous variables to evaluate the emulator
at. If the emulator is discrete, these are only allowed to be
the discrete independent variables that the emulator was trained at
Expand All @@ -98,12 +98,12 @@ def predict_values(
Returns
-------

dependent_predictions, np.array
dependent_predictions: np.array
Array of predictions, if the emulator is a function f, these
are the predicted values of f(independent) evaluted at the position
of the input ``model_parameters``.

dependent_prediction_errors, np.array
dependent_prediction_errors: np.array
Errors on the model predictions. For models where the errors are
unconstrained, this is an array of zeroes.

Expand All @@ -123,3 +123,90 @@ def predict_values(
)

raise NotImplementedError

def interactive_plot(
self,
x: np.array,
xlabel: str = "",
ylabel: str = "",
x_data: np.array = None,
y_data: np.array = None,
):
"""
Generates an interactive plot which displays the emulator predictions.
If no reference data is passed to be overplotted then the plot will
display a line which corresponds to the predictions for the mean
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it the mean or the midpoint of the ranges? (also out of curiosity, I guess it's the same for uniform points)

of the parameter values.

Parameters
----------

x: np.array
Array of data for which the emulator should make predictions.

xlabel: str, optional
Label for horizontal axis on the resultant figure.

ylabel: str, optional
Label for vertical axis on the resultant figure.

x_data: np.array, optional
Array containing x-values of reference data to plot.

y_data: np.array, optional
Array containing y-values of reference data to plot.
Must be the same shape as x_data
"""
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider

fig, ax = plt.subplots()
model_specification = self.model_specification
param_means = {}
sliders = []
n_param = model_specification.number_of_parameters
fig.subplots_adjust(bottom=0.12 + n_param * 0.1)
for i in range(n_param):
# Extracting information needed for slider
name = model_specification.parameter_names[i]
lo_lim = sorted(model_specification.parameter_limits[i])[0]
hi_lim = sorted(model_specification.parameter_limits[i])[1]
param_means[name] = (lo_lim + hi_lim) / 2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe refer to it as midpoints as that is more correct (also clarifies my earlier comment)


# Adding slider
if model_specification.parameter_printable_names:
name = model_specification.parameter_printable_names[i]
slider_ax = fig.add_axes([0.35, i * 0.1, 0.3, 0.1])
slider = Slider(
ax=slider_ax,
label=name,
valmin=lo_lim,
valmax=hi_lim,
valinit=(lo_lim + hi_lim) / 2,
)
sliders.append(slider)

# Setting up initial value
pred, pred_var = self.predict_values(x, param_means)
if (x_data is None) or (y_data is None):
ax.plot(x, pred, "k--")
else:
ax.plot(x_data, y_data, "k.")
(line,) = ax.plot(x, pred)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)

# Define and enable update function
def update(val):
params = {
model_specification.parameter_names[i]: sliders[i].val
for i in range(n_param)
}
pred, pred_var = self.predict_values(x, params)
line.set_ydata(pred)

for slider in sliders:
slider.on_changed(update)

plt.show()
plt.close()
16 changes: 10 additions & 6 deletions swiftemulator/emulators/multi_gaussian_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,16 @@ def predict_values(

for index, (low, high) in enumerate(self.independent_regions):
mask = np.logical_and(
independent > low
if low is not None
else np.ones_like(independent).astype(bool),
independent < high
if high is not None
else np.ones_like(independent).astype(bool),
(
independent > low
if low is not None
else np.ones_like(independent).astype(bool)
),
(
independent < high
if high is not None
else np.ones_like(independent).astype(bool)
),
)

predicted, errors = self.emulators[index].predict_values(
Expand Down
4 changes: 2 additions & 2 deletions swiftemulator/io/swift.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ def load_pipeline_outputs(
"adaptive_mass_function",
"histogram",
]
recursive_search = (
lambda d, k: d.get(k[0], recursive_search(d, k[1:])) if len(k) > 0 else None
recursive_search = lambda d, k: (
d.get(k[0], recursive_search(d, k[1:])) if len(k) > 0 else None
)
line_search = lambda d: recursive_search(d, line_types)

Expand Down
Loading