Skip to content

Commit

Permalink
Fixed some code quality issues from sonar
Browse files Browse the repository at this point in the history
  • Loading branch information
Tobychev committed May 15, 2024
1 parent 60e6db3 commit d6d4139
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 124 deletions.
64 changes: 64 additions & 0 deletions src/ctapipe/irf/vis_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import numpy as np
import scipy.stats as st


def find_columnwise_stats(table, col_bins, percentiles, density=False):
tab = np.squeeze(table)
out = np.ones((tab.shape[1], 5)) * -1
# This loop over the columns seems unavoidable,
# so having a reasonable number of bins in that
# direction is good
for idx, col in enumerate(tab.T):
if (col > 0).sum() == 0:
continue
col_est = st.rv_histogram((col, col_bins), density=density)
out[idx, 0] = col_est.mean()
out[idx, 1] = col_est.median()
out[idx, 2] = col_est.std()
out[idx, 3] = col_est.ppf(percentiles[0])
out[idx, 4] = col_est.ppf(percentiles[1])
return out


def rebin_x_2d_hist(hist, xbins, x_cent, num_bins_merge=3):
num_y, num_x = hist.shape
if (num_x) % num_bins_merge == 0:
rebin_x = xbins[::num_bins_merge]
rebin_xcent = x_cent.reshape(-1, num_bins_merge).mean(axis=1)
rebin_hist = hist.reshape(num_y, -1, num_bins_merge).sum(axis=2)
return rebin_x, rebin_xcent, rebin_hist
else:
raise ValueError(
f"Could not merge {num_bins_merge} along axis of dimension {num_x}"
)


def get_2d_hist_from_table(x_prefix, y_prefix, table, column):
x_lo_name, x_hi_name = f"{x_prefix}_LO", f"{x_prefix}_HI"
y_lo_name, y_hi_name = f"{y_prefix}_LO", f"{y_prefix}_HI"

xbins = np.hstack((table[x_lo_name][0], table[x_hi_name][0][-1]))
ybins = np.hstack((table[y_lo_name][0], table[y_hi_name][0][-1]))

if isinstance(column, str):
mat_vals = np.squeeze(table[column])
else:
mat_vals = column

return mat_vals, xbins, ybins


def get_bin_centers(bins):
return np.convolve(bins, kernel=[0.5, 0.5], mode="valid")


def get_x_bin_values_with_rebinning(num_rebin, xbins, xcent, mat_vals, density):
if num_rebin > 1:
rebin_x, rebin_xcent, rebin_hist = rebin_x_2d_hist(
mat_vals, xbins, xcent, num_bins_merge=num_rebin
)
density = False
else:
rebin_x, rebin_xcent, rebin_hist = xbins, xcent, mat_vals

return rebin_x, rebin_xcent, rebin_hist, density
162 changes: 38 additions & 124 deletions src/ctapipe/irf/visualisation.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,37 @@
import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
from astropy.visualization import quantity_support
from matplotlib.colors import LogNorm
from pyirf.binning import join_bin_lo_hi

from .vis_utils import (
find_columnwise_stats,
get_2d_hist_from_table,
get_bin_centers,
get_x_bin_values_with_rebinning,
)

quantity_support()


def plot_2D_irf_table(
def plot_2d_irf_table(
ax, table, column, x_prefix, y_prefix, x_label=None, y_label=None, **mpl_args
):
x_lo_name, x_hi_name = f"{x_prefix}_LO", f"{x_prefix}_HI"
y_lo_name, y_hi_name = f"{y_prefix}_LO", f"{y_prefix}_HI"

xbins = np.hstack((table[x_lo_name][0], table[x_hi_name][0][-1]))
mat_vals, xbins, ybins = get_2d_hist_from_table(x_prefix, y_prefix, table, column)

ybins = np.hstack((table[y_lo_name][0], table[y_hi_name][0][-1]))
if not x_label:
x_label = x_prefix
if not y_label:
y_label = y_prefix
if isinstance(column, str):
mat_vals = np.squeeze(table[column])
else:
mat_vals = column

plot = plot_hist2D(
plot = plot_hist2d(
ax, mat_vals, xbins, ybins, xlabel=x_label, ylabel=y_label, **mpl_args
)
plt.colorbar(plot)
return ax


def rebin_x_2D_hist(hist, xbins, x_cent, num_bins_merge=3):
num_y, num_x = hist.shape
if (num_x) % num_bins_merge == 0:
rebin_x = xbins[::num_bins_merge]
rebin_xcent = x_cent.reshape(-1, num_bins_merge).mean(axis=1)
rebin_hist = hist.reshape(num_y, -1, num_bins_merge).sum(axis=2)
return rebin_x, rebin_xcent, rebin_hist
else:
raise ValueError(
f"Could not merge {num_bins_merge} along axis of dimension {num_x}"
)


def find_columnwise_stats(table, col_bins, percentiles, density=False):
tab = np.squeeze(table)
out = np.ones((tab.shape[1], 5)) * -1
# This loop over the columns seems unavoidable,
# so having a reasonable number of bins in that
# direction is good
for idx, col in enumerate(tab.T):
if (col > 0).sum() == 0:
continue
col_est = st.rv_histogram((col, col_bins), density=density)
out[idx, 0] = col_est.mean()
out[idx, 1] = col_est.median()
out[idx, 2] = col_est.std()
out[idx, 3] = col_est.ppf(percentiles[0])
out[idx, 4] = col_est.ppf(percentiles[1])
return out


def plot_2D_table_with_col_stats(
def plot_2d_table_with_col_stats(
ax,
table,
column,
Expand All @@ -88,35 +54,14 @@ def plot_2D_table_with_col_stats(
1 -> median + standard deviation
2 -> median + user specified quantiles around median (default 0.1 to 0.9)
"""
x_lo_name, x_hi_name = f"{x_prefix}_LO", f"{x_prefix}_HI"
y_lo_name, y_hi_name = f"{y_prefix}_LO", f"{y_prefix}_HI"

xbins = np.hstack((table[x_lo_name][0], table[x_hi_name][0][-1]))

ybins = np.hstack((table[y_lo_name][0], table[y_hi_name][0][-1]))

xcent = np.convolve(
[0.5, 0.5], np.hstack((table[x_lo_name][0], table[x_hi_name][0][-1])), "valid"
mat_vals, xbins, ybins = get_2d_hist_from_table(x_prefix, y_prefix, table, column)
xcent = get_bin_centers(xbins)
rebin_x, rebin_xcent, rebin_hist = get_x_bin_values_with_rebinning(
num_rebin, xbins, xcent, mat_vals, density
)
if not x_label:
x_label = x_prefix
if not y_label:
y_label = y_prefix
if isinstance(column, str):
mat_vals = np.squeeze(table[column])
else:
mat_vals = column

if num_rebin > 1:
rebin_x, rebin_xcent, rebin_hist = rebin_x_2D_hist(
mat_vals, xbins, xcent, num_bins_merge=num_rebin
)
density = False
else:
rebin_x, rebin_xcent, rebin_hist = xbins, xcent, mat_vals

stats = find_columnwise_stats(rebin_hist, ybins, quantiles, density)
plot = plot_hist2D(
plot = plot_hist2d(
ax,
rebin_hist,
rebin_x,
Expand All @@ -127,36 +72,25 @@ def plot_2D_table_with_col_stats(
)
plt.colorbar(plot)

sel = stats[:, 0] > 0
if stat_kind == 1:
y_idx = 0
err = stats[sel, 2]
label = "mean + std"
if stat_kind == 2:
y_idx = 1
err = stats[sel, 2]
label = "median + std"
if stat_kind == 3:
y_idx = 1
err = np.zeros_like(stats[:, 3:])
err[sel, 0] = stats[sel, 1] - stats[sel, 3]
err[sel, 1] = stats[sel, 4] - stats[sel, 1]
err = err[sel, :].T
label = f"median + IRQ[{quantiles[0]:.2f},{quantiles[1]:.2f}]"

ax.errorbar(
x=rebin_xcent[sel],
y=stats[sel, y_idx],
yerr=err,
label=label,
**mpl_args["stats"],
ax = plot_2d_table_col_stats(
ax,
table,
column,
x_prefix,
y_prefix,
num_rebin,
stat_kind,
quantiles,
x_label,
y_label,
density,
mpl_args,
lbl_prefix="",
)
ax.legend(loc="best")

return ax


def plot_2D_table_col_stats(
def plot_2d_table_col_stats(
ax,
table,
column,
Expand All @@ -171,38 +105,18 @@ def plot_2D_table_col_stats(
lbl_prefix="",
mpl_args={"xscale": "log"},
):
"""Function to draw columnwise statistics of 2D hist
"""Function to draw columnwise statistics of 2d hist
the content values shown depending on stat_kind:
0 -> mean + standard deviation
1 -> median + standard deviation
2 -> median + user specified quantiles around median (default 0.1 to 0.9)
"""
x_lo_name, x_hi_name = f"{x_prefix}_LO", f"{x_prefix}_HI"
y_lo_name, y_hi_name = f"{y_prefix}_LO", f"{y_prefix}_HI"

xbins = np.hstack((table[x_lo_name][0], table[x_hi_name][0][-1]))

ybins = np.hstack((table[y_lo_name][0], table[y_hi_name][0][-1]))

xcent = np.convolve(
[0.5, 0.5], np.hstack((table[x_lo_name][0], table[x_hi_name][0][-1])), "valid"
mat_vals, xbins, ybins = get_2d_hist_from_table(x_prefix, y_prefix, table, column)
xcent = get_bin_centers(xbins)
rebin_x, rebin_xcent, rebin_hist = get_x_bin_values_with_rebinning(
num_rebin, xbins, xcent, mat_vals, density
)
if not x_label:
x_label = x_prefix
if not y_label:
y_label = y_prefix
if isinstance(column, str):
mat_vals = np.squeeze(table[column])
else:
mat_vals = column

if num_rebin > 1:
rebin_x, rebin_xcent, rebin_hist = rebin_x_2D_hist(
mat_vals, xbins, xcent, num_bins_merge=num_rebin
)
density = False
else:
rebin_xcent, rebin_hist = xcent, mat_vals

stats = find_columnwise_stats(rebin_hist, ybins, quantiles, density)

Expand Down Expand Up @@ -262,7 +176,7 @@ def plot_irf_table(
ax.stairs(vals, bins, label=label, **mpl_args)


def plot_hist2D_as_contour(
def plot_hist2d_as_contour(
ax,
hist,
xedges,
Expand All @@ -283,7 +197,7 @@ def plot_hist2D_as_contour(
return out


def plot_hist2D(
def plot_hist2d(
ax,
hist,
xedges,
Expand Down

0 comments on commit d6d4139

Please sign in to comment.