Skip to content

Commit

Permalink
Merge pull request #33 from raymondEhlers/dev
Browse files Browse the repository at this point in the history
Last updates for QM
  • Loading branch information
raymondEhlers authored Sep 14, 2023
2 parents 6758b6b + 45b7927 commit e6c8206
Show file tree
Hide file tree
Showing 4 changed files with 262 additions and 37 deletions.
54 changes: 29 additions & 25 deletions config/jet_substructure.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ plot:
mcmc: True
qhat: True
closure_tests: False
across_analyses: True

debug_level: 0

Expand Down Expand Up @@ -213,31 +214,31 @@ analyses:
########################
# Main analyses for QM23
########################
#analysis_jet:
# parameters:
# preprocessing:
# <<: *default_preprocessing_parameters
# emulators:
# default_group:
# <<: *default_emulator_parameters
# n_pc: 5
# observable_list:
# - 'jet__pt_'
# observable_exclude_list:
# - "pt_y_atlas"
# - "2760__PbPb__inclusive_chjet__pt_alice"
# mcmc:
# <<: *default_mcmc_parameters
# n_walkers: 200
# n_burn_steps: 1000
# n_sampling_steps: 25000
# closure:
# <<: *default_closure_parameters
# <<: *model_base
# cuts:
# 'chjet__pt_star__R0.2': [14, 100]
# 'chjet__pt_star__R0.4': [16, 100]
# plot_panel_shapes: [[3,3], [3,3], [3,3]]
analysis_jet:
parameters:
preprocessing:
<<: *default_preprocessing_parameters
emulators:
default_group:
<<: *default_emulator_parameters
n_pc: 5
observable_list:
- 'jet__pt_'
observable_exclude_list:
- "pt_y_atlas"
- "2760__PbPb__inclusive_chjet__pt_alice"
mcmc:
<<: *default_mcmc_parameters
n_walkers: 200
n_burn_steps: 1000
n_sampling_steps: 25000
closure:
<<: *default_closure_parameters
<<: *model_base
cuts:
'chjet__pt_star__R0.2': [14, 100]
'chjet__pt_star__R0.4': [16, 100]
plot_panel_shapes: [[3,3], [3,3], [3,3]]

analysis_jet_substructure_n_walkers_100_long_prod:
parameters:
Expand Down Expand Up @@ -280,6 +281,9 @@ analyses:






#analysis_hadron10_1000:
# <<: *analysis_hadron
# cuts:
Expand Down
217 changes: 217 additions & 0 deletions src/bayesian_inference/plot_analyses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
"""Plot across analyses
authors: J. Mulligan, R. Ehlers
"""

from __future__ import annotations
from typing import Any

import logging
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.set_context('paper', rc={'font.size':18,'axes.titlesize':18,'axes.labelsize':18})

from bayesian_inference import data_IO, mcmc
from bayesian_inference import plot_qhat

logger = logging.getLogger(__name__)


def plot(analyses: dict[str, Any], config_file: str, output_dir: str) -> None:
    """Plot qhat across the selected analyses.

    One ``mcmc.MCMCConfig`` is built per (analysis, parameterization) pair. The stored
    MCMC chain of each configuration is read back, flattened over walkers and steps,
    and all posteriors are passed to ``plot_qhat_across_analyses``.

    :param dict[str, Any] analyses: analysis configurations keyed by analysis name; each one
        must provide a 'parameterizations' entry to iterate over
    :param str config_file: configuration file forwarded to every MCMCConfig
    :param str output_dir: base output directory; plots are written to ``<output_dir>/plot_analyses``
    :return None: we save plots to disk
    """
    # Setup: one MCMCConfig per (analysis, parameterization) pair
    configs = {}
    for analysis_name, analysis_config in analyses.items():
        for parameterization in analysis_config['parameterizations']:
            configs[f"{analysis_name}_{parameterization}"] = mcmc.MCMCConfig(
                analysis_name=analysis_name,
                parameterization=parameterization,
                analysis_config=analysis_config,
                config_file=config_file
            )

    # Validation and loading of the MCMC output
    results = {}
    posteriors = {}
    for config_name, config in configs.items():
        # Check if mcmc.h5 file exists. The comparison needs every analysis, so a single
        # missing file aborts the whole plot -- warn so that this does not go unnoticed.
        if not os.path.exists(config.mcmc_outputfile):
            logger.warning(
                f'MCMC output does not exist: {config.mcmc_outputfile}. Skipping plots across analyses.'
            )
            return

        # Get results from file
        results[config_name] = data_IO.read_dict_from_h5(config.output_dir, config.mcmc_outputfilename, verbose=True)
        # Flatten the chain over walkers and steps: (n_walkers, n_steps, n_params) -> (n_walkers*n_steps, n_params)
        chain = results[config_name]['chain']
        n_walkers, n_steps, n_params = chain.shape
        posteriors[config_name] = chain.reshape((n_walkers*n_steps, n_params))

    # Plot output dir (exist_ok avoids the check-then-create race)
    plot_dir = os.path.join(output_dir, 'plot_analyses')
    os.makedirs(plot_dir, exist_ok=True)

    plot_qhat_across_analyses(
        results=results,
        posteriors=posteriors,
        configs=configs,
        plot_dir=plot_dir,
        E=100,
        cred_level=0.9,
        n_samples=5000,
        plot_mean=False,
    )


#---------------------------------------------------------------[]
def plot_qhat_across_analyses(
    results,
    posteriors,
    plot_dir,
    configs, E=0, T=0, cred_level=0., n_samples=5000, n_x=50,
    plot_prior=True, plot_mean=True, plot_map=False, target_design_point=np.array([])):
    '''
    Plot qhat credible intervals from the posterior samples of several analyses on one figure,
    as a function of either E or T (with the other held fixed).

    Pretty much copied from plot_qhat, hastily modified due to QM.

    The three dictionaries are iterated in parallel (by insertion order), so they are expected
    to describe the same analyses in the same order.

    :param dict results: MCMC results per analysis; only the keys (analysis names) are used, to build labels
    :param dict posteriors: posterior samples per analysis -- each of shape (n_walkers*n_steps, n_params)
    :param str plot_dir: directory in which qhat_<suffix>.pdf is saved
    :param dict configs: configuration per analysis, passed through to plot_qhat.qhat
    :param float E: fix jet energy (GeV), and plot as a function of T. Takes precedence over T if both are set
    :param float T: fix temperature (GeV), and plot as a function of E
    :param float cred_level: credible interval level
    :param int n_samples: maximum number of posterior samples to use for plotting (per analysis)
    :param int n_x: number of T or E points to plot
    :param bool plot_prior: draw the prior credible interval (once, using the first analysis)
    :param bool plot_mean: draw the mean qhat of each analysis
    :param bool plot_map: draw the MAP qhat of each analysis
    :param 1darray target_design_point: if closure test, design point corresponding to "truth" qhat value
    :return: dict of closure information if target_design_point is non-empty, otherwise None
    :raises ValueError: if neither E nor T is set, or if there are no analyses to plot
    '''
    # Fail early with a clear message rather than on an unbound local further down
    if not E and not T:
        raise ValueError('Either E or T must be nonzero to select which variable is held fixed')
    if not results:
        raise ValueError('No analyses were provided, so there is nothing to plot')

    # The x axis only depends on which variable is held fixed, so set it up once
    if E:
        xlabel = 'T (GeV)'
        suffix = f'E{E}'
        label = f'E = {E} GeV'
        x_array = np.linspace(0.16, 0.5, n_x)
    else:
        xlabel = 'E (GeV)'
        suffix = f'T{T}'
        label = f'T = {T} GeV'
        x_array = np.linspace(5, 200, n_x)

    def _qhat_vs_x(parameters, config):
        # Evaluate qhat at every x point, scanning T at fixed E or E at fixed T
        if E:
            return [plot_qhat.qhat(parameters, config, T=x, E=E) for x in x_array]
        return [plot_qhat.qhat(parameters, config, T=T, E=x) for x in x_array]

    colors = [
        sns.xkcd_rgb['light blue'],
        "#FF8301", # orange,
    ]
    if len(results) > len(colors):
        # zip() below stops at the shortest input, so the extra analyses are not drawn
        logger.warning(f'Only {len(colors)} colors are defined, so only the first {len(colors)} of {len(results)} analyses will be plotted')
    already_drawn_prior_credible_interval = False

    fig, ax = plt.subplots()
    for color, analysis_name, posterior, config in zip(colors, results, posteriors.values(), configs.values()):
        # TODO: Labels are hard coded...
        analysis_label = r"Jet $R_{\mathrm{AA}}$"
        if "substructure" in analysis_name:
            analysis_label = r"Jet $R_{\mathrm{AA}}$ + substructure"

        # Sample posterior parameters without replacement.
        # The cap is kept local so that one short chain does not reduce the number of
        # samples drawn for the analyses which follow it.
        n_draw = n_samples
        if posterior.shape[0] < n_samples:
            n_draw = posterior.shape[0]
            logger.warning(f'Not enough posterior samples to plot {n_samples} samples, using {n_draw} instead')
        idx = np.random.choice(posterior.shape[0], size=n_draw, replace=False)
        posterior_samples = posterior[idx,:]

        # Compute qhat for each sample, as a function of T or E
        # qhat_posteriors will be a 2d array of shape (x_array.size, n_draw)
        qhat_posteriors = np.array(_qhat_vs_x(posterior_samples, config))

        # Plot mean qhat values for each T or E
        qhat_mean = np.mean(qhat_posteriors, axis=1)
        if plot_mean:
            ax.plot(x_array, qhat_mean, color=color,
                    linewidth=2., linestyle='--', label=f'{analysis_label}: Mean')

        # Plot the MAP value as well for each T or E
        if plot_map:
            # The MAP parameters do not depend on x, so look them up once per analysis
            # NOTE(review): assumes mcmc.map_parameters is deterministic for fixed samples -- confirm
            map_parameters = mcmc.map_parameters(posterior_samples)
            qhat_map = np.array(_qhat_vs_x(map_parameters, config))
            # No explicit color is given here, so matplotlib's default color cycle is used
            ax.plot(x_array, qhat_map,
                    linewidth=2., linestyle='--', label=f'{analysis_label}: MAP')

        # Plot prior as well, for comparison. It is only drawn once (for the first analysis).
        # TODO: one could also plot some type of "information gain" metric, e.g. KL divergence
        if plot_prior and not already_drawn_prior_credible_interval:

            # Generate samples
            prior_samples = plot_qhat._generate_prior_samples(config, n_samples=n_draw)

            # Compute qhat for each sample, as a function of T or E
            qhat_priors = np.array(_qhat_vs_x(prior_samples, config))

            # Get credible interval for each T or E
            h_prior = [mcmc.credible_interval(qhat_values, confidence=cred_level) for qhat_values in qhat_priors]
            credible_low_prior = [i[0] for i in h_prior]
            credible_up_prior = [i[1] for i in h_prior]
            ax.fill_between(x_array, credible_low_prior, credible_up_prior, color=color,
                            alpha=0.3, label=f'Prior {int(cred_level*100)}% Credible Interval (CI)')
            already_drawn_prior_credible_interval = True

        # Get credible interval for each T or E
        h = [mcmc.credible_interval(qhat_values, confidence=cred_level) for qhat_values in qhat_posteriors]
        credible_low = [i[0] for i in h]
        credible_up = [i[1] for i in h]
        ax.fill_between(x_array, credible_low, credible_up, color=color,
                        label=f'{analysis_label}: Posterior {int(cred_level*100)}% CI')

    # If closure test: Plot truth qhat value
    # We will return a dict of info needed for plotting closure plots, including a
    # boolean array (as a fcn of T or E) of whether the truth value is contained within credible region
    # NOTE: config, qhat_mean and the credible interval used here are those of the last analysis drawn
    if target_design_point.any():
        qhat_truth = _qhat_vs_x(target_design_point, config)
        ax.plot(x_array, qhat_truth, color=sns.xkcd_rgb['pale red'],
                linewidth=2., label='Target')

        qhat_closure = {}
        qhat_closure['qhat_closure_array'] = np.array([((qhat_truth[i] < credible_up[i]) and (qhat_truth[i] > credible_low[i])) for i,_ in enumerate(x_array)]).squeeze()
        qhat_closure['qhat_mean'] = qhat_mean
        qhat_closure['x_array'] = x_array
        qhat_closure['cred_level'] = cred_level

    # Plot formatting
    ax.set_xlabel(xlabel)
    ax.set_ylabel(r'$\hat{q}/T^3$')
    ymin = 0
    # The upper limit is derived from the last analysis drawn
    if plot_map:
        ymax = 2*max(qhat_map)
    else:
        # Use mean in all other cases
        ymax = 2*max(qhat_mean)
    ax.set_ylim([ymin, ymax])
    ax.legend(title=f'{label}', title_fontsize=12,
              loc='upper right', fontsize=12, frameon=False)

    # Add preliminary label
    # TODO: Remove this...
    #ax.text(0.05, 0.05,
    #    'JETSCAPE Preliminary',
    #    horizontalalignment="left",
    #    verticalalignment="bottom",
    #    multialignment="left",
    #    transform=ax.transAxes
    #    )

    fig.tight_layout()
    fig.savefig(f'{plot_dir}/qhat_{suffix}.pdf')
    plt.close('all')

    if target_design_point.any():
        return qhat_closure
22 changes: 11 additions & 11 deletions src/bayesian_inference/plot_qhat.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def plot(config):
_plot_observable_sensitivity(posterior, plot_dir, config, delta=0.1, n_samples=1000)

#---------------------------------------------------------------[]
def plot_qhat(posterior, plot_dir, config, E=0, T=0, cred_level=0., n_samples=5000, n_x=50,
def plot_qhat(posterior, plot_dir, config, E=0, T=0, cred_level=0., n_samples=5000, n_x=50,
plot_prior=True, plot_mean=True, plot_map=False, target_design_point=np.array([])):
'''
Plot qhat credible interval from posterior samples,
Expand Down Expand Up @@ -155,7 +155,7 @@ def plot_qhat(posterior, plot_dir, config, E=0, T=0, cred_level=0., n_samples=50
ymin = 0
if plot_mean:
ymax = 2*max(qhat_mean)
elif plot_map:
elif plot_map:
ymax = 2*max(qhat_map)
axes = plt.gca()
axes.set_ylim([ymin, ymax])
Expand All @@ -177,7 +177,7 @@ def _plot_observable_sensitivity(posterior, plot_dir, config, delta=0.1, n_sampl
Note: this is just a normalized partial derivative, dO_j/dx_i * (x_i/O_j)
Based on:
Based on:
- https://arxiv.org/abs/2011.01430
- https://link.springer.com/article/10.1007/BF00547132
'''
Expand All @@ -187,17 +187,17 @@ def _plot_observable_sensitivity(posterior, plot_dir, config, delta=0.1, n_sampl

# Plot sensitivity index for each parameter
for i_parameter in range(posterior.shape[1]):
_plot_single_parameter_observable_sensitivity(map_parameters, i_parameter,
_plot_single_parameter_observable_sensitivity(map_parameters, i_parameter,
plot_dir, config, delta=delta)

# TODO: Plot sensitivity for qhat:
# S(qhat, O_j, delta) = 1/delta * [O_j(qhat_map') - O_j(qhat_map)] / O_j(qhat)
# In the current qhat formulation, qhat = qhat(x_0=alpha_s_fix) only depends on x_0=alpha_s_fix.
# So this information is already captured in the x_0 sensitivity plot above.
# If we want to explicitly compute S(qhat), we need to evaluate the emulator at qhat_map'=(1+delta)*qhat_map.
# In principle one should find the x_0 corresponding to (1+delta)*qhat_map.
# For simplicity we can just evaluate x_0'=x_0(1+delta) and then redefine delta=qhat(x_0')-qhat(x_0) -- but
# this is excatly the same as the S(x_0) plot above, up the redefinition of delta.
# this is exactly the same as the S(x_0) plot above, up to the redefinition of delta.
# It may nevertheless be nice to add since a plot of S(qhat) will likely be more salient to viewers.

#---------------------------------------------------------------
Expand All @@ -218,7 +218,7 @@ def _plot_single_parameter_observable_sensitivity(map_parameters, i_parameter, p
x_prime[i_parameter] = (1+delta)*x_prime[i_parameter]
x = np.expand_dims(x, axis=0)
x_prime = np.expand_dims(x_prime, axis=0)

# Get emulator predictions at the two points
emulation_config = emulation.EmulationConfig.from_config_file(
analysis_name=config.analysis_name,
Expand All @@ -232,9 +232,9 @@ def _plot_single_parameter_observable_sensitivity(map_parameters, i_parameter, p

# Convert to dict: emulator_predictions[observable_label]
observables = data_IO.read_dict_from_h5(config.output_dir, 'observables.h5', verbose=False)
emulator_predictions_x_dict = data_IO.observable_dict_from_matrix(emulator_predictions_x['central_value'],
emulator_predictions_x_dict = data_IO.observable_dict_from_matrix(emulator_predictions_x['central_value'],
observables, observable_filter=emulation_config.observable_filter)
emulator_predictions_x_prime_dict = data_IO.observable_dict_from_matrix(emulator_predictions_x_prime['central_value'],
emulator_predictions_x_prime_dict = data_IO.observable_dict_from_matrix(emulator_predictions_x_prime['central_value'],
observables, observable_filter=emulation_config.observable_filter)

# Construct dict of sensitivity index, in same format as emulator_predictions['central_value']
Expand All @@ -244,7 +244,7 @@ def _plot_single_parameter_observable_sensitivity(map_parameters, i_parameter, p
x = emulator_predictions_x_dict['central_value'][observable_label]
x_prime = emulator_predictions_x_prime_dict['central_value'][observable_label]
sensitivity_index_dict[observable_label] = 1/delta * (x_prime - x) / x

# Plot
plot_list = [sensitivity_index_dict]
columns = [0]
Expand All @@ -254,7 +254,7 @@ def _plot_single_parameter_observable_sensitivity(map_parameters, i_parameter, p
ylabel = rf'$S({param}, \mathcal{{O}}, \delta)$'
#ylabel = rf'$S({param}, \mathcal{{O}}, \delta) = \frac{{1}}{{\delta}} \frac{{\mathcal{{O}}([1+\delta] {param})-\mathcal{{O}}({param})}}{{\mathcal{{O}}({param})}}$'
filename = f'sensitivity_index_{i_parameter}.pdf'
plot_utils.plot_observable_panels(plot_list, labels, colors, columns, config, plot_dir, filename,
plot_utils.plot_observable_panels(plot_list, labels, colors, columns, config, plot_dir, filename,
linewidth=1, ymin=-5, ymax=5, ylabel=ylabel, plot_exp_data=False, bar_plot=True)

#---------------------------------------------------------------
Expand Down
6 changes: 5 additions & 1 deletion src/bayesian_inference/steer_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pathlib import Path

from bayesian_inference import data_IO, preprocess_input_data, emulation, mcmc
from bayesian_inference import plot_input_data, plot_emulation, plot_mcmc, plot_qhat, plot_closure
from bayesian_inference import plot_input_data, plot_emulation, plot_mcmc, plot_qhat, plot_closure, plot_analyses

from bayesian_inference import common_base, helpers

Expand Down Expand Up @@ -258,6 +258,10 @@ def run_analysis(self):
logger.info("")

# Plots across multiple analyses
if self.plot['across_analyses']:
# NOTE: This is a departure from the standard API, but we need a convention for how
# to pass multiple analyses, so we'll just go with it for now.
plot_analyses.plot(self.analyses, self.config_file, self.output_dir)


####################################################################################################################
Expand Down

0 comments on commit e6c8206

Please sign in to comment.