-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Interactive plot #46
Interactive plot #46
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -153,6 +153,25 @@ This is implemented into the SWIFT-Emulator with | |
sweep as `ModelValues` and `ModelParameters` | ||
containers, that are easy to parse. | ||
|
||
Interactive plots | ||
----------------- | ||
|
||
Another way to explore the effect of varying the parameters is | ||
to try an interactive plot. Every emulator object contains an | ||
`interactive_plot` method. This generates a plot with a slider | ||
for each parameter. The plot will update to show the emulator | ||
predictions when sliders are adjusted. | ||
|
||
.. code-block:: python | ||
|
||
schecter_emulator.interactive_plot(predict_x, xlabel="Stellar mass", ylabel="dn/dlogM") | ||
|
||
.. image:: interactive_plot.png | ||
|
||
It is possible to pass reference data to be plotted when calling | ||
:meth:`swiftemulator.emulators.base.BaseEmulator.interactive\_plot`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe you could add this to the example? Similar to the previous comment I would at least mention the reference line and/or the extra datapoints |
||
|
||
|
||
Model Parameters Features | ||
------------------------- | ||
|
||
|
@@ -234,4 +253,4 @@ This method is a lot slower than the default hyperparameter | |
optimisation, and may take some time to compute. The main | ||
take away from plots like this is to see whether the | ||
hyperparameters are converged, and whether they are | ||
consistent with the faster optimisation method. | ||
consistent with the faster optimisation method. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -85,7 +85,7 @@ def predict_values( | |
Parameters | ||
---------- | ||
|
||
independent, np.array | ||
independent: np.array | ||
Independent continuous variables to evaluate the emulator | ||
at. If the emulator is discrete, these are only allowed to be | ||
the discrete independent variables that the emulator was trained at | ||
|
@@ -98,12 +98,12 @@ def predict_values( | |
Returns | ||
------- | ||
|
||
dependent_predictions, np.array | ||
dependent_predictions: np.array | ||
Array of predictions, if the emulator is a function f, these | ||
are the predicted values of f(independent) evaluted at the position | ||
of the input ``model_parameters``. | ||
|
||
dependent_prediction_errors, np.array | ||
dependent_prediction_errors: np.array | ||
Errors on the model predictions. For models where the errors are | ||
unconstrained, this is an array of zeroes. | ||
|
||
|
@@ -123,3 +123,90 @@ def predict_values( | |
) | ||
|
||
raise NotImplementedError | ||
|
||
def interactive_plot( | ||
self, | ||
x: np.array, | ||
xlabel: str = "", | ||
ylabel: str = "", | ||
x_data: np.array = None, | ||
y_data: np.array = None, | ||
): | ||
""" | ||
Generates an interactive plot which displays the emulator predictions. | ||
If no reference data is passed to be overplotted then the plot will | ||
display a line which corresponds to the predictions for the mean | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it the mean or the midpoint of the ranges? (also out of curiosity, I guess it's the same for uniform points) |
||
of the parameter values. | ||
|
||
Parameters | ||
---------- | ||
|
||
x: np.array | ||
Array of data for which the emulator should make predictions. | ||
|
||
xlabel: str, optional | ||
Label for horizontal axis on the resultant figure. | ||
|
||
ylabel: str, optional | ||
Label for vertical axis on the resultant figure. | ||
|
||
x_data: np.array, optional | ||
Array containing x-values of reference data to plot. | ||
|
||
y_data: np.array, optional | ||
Array containing y-values of reference data to plot. | ||
Must be the same shape as x_data | ||
""" | ||
import matplotlib.pyplot as plt | ||
from matplotlib.widgets import Slider | ||
|
||
fig, ax = plt.subplots() | ||
model_specification = self.model_specification | ||
param_means = {} | ||
sliders = [] | ||
n_param = model_specification.number_of_parameters | ||
fig.subplots_adjust(bottom=0.12 + n_param * 0.1) | ||
for i in range(n_param): | ||
# Extracting information needed for slider | ||
name = model_specification.parameter_names[i] | ||
lo_lim = sorted(model_specification.parameter_limits[i])[0] | ||
hi_lim = sorted(model_specification.parameter_limits[i])[1] | ||
param_means[name] = (lo_lim + hi_lim) / 2 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe refer to it as midpoints as that is more correct (also clarifies my earlier comment) |
||
|
||
# Adding slider | ||
if model_specification.parameter_printable_names: | ||
name = model_specification.parameter_printable_names[i] | ||
slider_ax = fig.add_axes([0.35, i * 0.1, 0.3, 0.1]) | ||
slider = Slider( | ||
ax=slider_ax, | ||
label=name, | ||
valmin=lo_lim, | ||
valmax=hi_lim, | ||
valinit=(lo_lim + hi_lim) / 2, | ||
) | ||
sliders.append(slider) | ||
|
||
# Setting up initial value | ||
pred, pred_var = self.predict_values(x, param_means) | ||
if (x_data is None) or (y_data is None): | ||
ax.plot(x, pred, "k--") | ||
else: | ||
ax.plot(x_data, y_data, "k.") | ||
(line,) = ax.plot(x, pred) | ||
ax.set_xlabel(xlabel) | ||
ax.set_ylabel(ylabel) | ||
|
||
# Define and enable update function | ||
def update(val): | ||
params = { | ||
model_specification.parameter_names[i]: sliders[i].val | ||
for i in range(n_param) | ||
} | ||
pred, pred_var = self.predict_values(x, params) | ||
line.set_ydata(pred) | ||
|
||
for slider in sliders: | ||
slider.on_changed(update) | ||
|
||
plt.show() | ||
plt.close() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe add a line that explains the default reference line.