Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Show spatial residuals for sceence frames - tickets/PIPE2D-1661 #52

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 116 additions & 122 deletions python/pfs/drp/qa/dmCombinedResiduals.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
Input as InputConnection,
Output as OutputConnection,
)
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from pandas import DataFrame
from pfs.drp.stella import DetectorMap

from pfs.drp.qa.dmResiduals import plot_detectormap_residuals
Expand Down Expand Up @@ -122,8 +120,8 @@ def runQuantum(
def run(
self,
detectorMaps: Iterable[DetectorMap],
dmQaResidualData: Iterable[DataFrame],
dmQaResidualStats: Iterable[DataFrame],
dmQaResidualData: Iterable[pd.DataFrame],
dmQaResidualStats: Iterable[pd.DataFrame],
run_name: str,
) -> Struct:
"""Create detector level residual_stats and plots.
Expand Down Expand Up @@ -177,8 +175,8 @@ def get_ccd(dm: DetectorMap) -> str:


def make_report(
residual_stats: DataFrame,
arc_data: DataFrame,
residual_stats: pd.DataFrame,
residual_data: pd.DataFrame,
detectorMaps: Dict[str, DetectorMap],
run_name: str,
log: object,
Expand All @@ -190,9 +188,6 @@ def make_report(
# Add the title as a figure.
pdf.append(plot_title(run_name))

# Add the table data as a figure.
pdf.append(plot_dataframe(reserved_stats))

# Detector summaries.
log.info("Making detector summary plots")
pdf.append(plot_detector_summary(reserved_stats))
Expand Down Expand Up @@ -224,19 +219,19 @@ def make_report(
# Add the 2D residual plot.
arm = ccd[0]
spec = int(ccd[1])
data = arc_data.query(f"arm == '{arm}' and spectrograph == {spec}")
plot_data = residual_data.query(f"arm == '{arm}' and spectrograph == {spec}")

# If we are doing a combined report we want to get the mean across visits.
grouped = data[plot_cols].groupby(["status", "isLine", "fiberId", "y"])
data = grouped.mean().reset_index()
grouped = plot_data[plot_cols].groupby(["status", "isLine", "fiberId", "y"])
plot_data = grouped.mean().reset_index()

residFig = plot_detectormap_residuals(data, visit_stats, detectorMaps[str(ccd)])
residFig = plot_detectormap_residuals(plot_data, visit_stats, detectorMaps[str(ccd)])
residFig.suptitle(f"DetectorMap Residuals - Median of all visits - {ccd}", weight="bold")
pdf.append(residFig, dpi=150)

# Add the description per visit breakdown.
fig = plot_detector_visits(visit_stats)
fig.suptitle(f"{fig.get_suptitle()}\n{ccd}")
fig = plot_visits(visit_stats.query('status_type == "RESERVED"'), palette=description_palette)
fig.suptitle(f"{fig.get_suptitle()} - {ccd}")
pdf.append(fig)
except KeyError:
log.warning(f"DetectorMap not found for {ccd}. Skipping.")
Expand All @@ -247,86 +242,86 @@ def make_report(
return pdf


def plot_detector_visits(plot_data: DataFrame) -> Figure:
summary_stats = plot_data.filter(regex="median|weighted").mean().to_dict()

fig = plot_visits(plot_data, palette=description_palette)

for ax, dim in zip(fig.axes, ["spatial", "wavelength"]):
upper_range = summary_stats[f"{dim}.median"] + summary_stats[f"{dim}.weightedRms"]
lower_range = summary_stats[f"{dim}.median"] - summary_stats[f"{dim}.weightedRms"]

ax.axvline(summary_stats[f"{dim}.median"], c="k", ls="--")
ax.axvline(upper_range, c="g", ls="--")
ax.axvline(lower_range, c="g", ls="--")
ax.set_title(
f"{dim.upper()}: "
f'median={summary_stats[f"{dim}.median"]:5.04f} '
f'rms={summary_stats[f"{dim}.weightedRms"]:5.04f}'
)

fig.set_size_inches(8, 8)

return fig


def plot_detector_summary(stats: DataFrame) -> Figure:
def plot_detector_summary(stats: pd.DataFrame) -> Figure:
plot_data_spatial = (
stats.query("description == 'Trace'")
.filter(regex="ccd|median|weighted|soften")
.filter(regex="ccd|spatial.(median|weighted|soften)")
.groupby("ccd", observed=False)
.mean()
)
plot_data_spatial.columns = [c.replace("spatial.", "") for c in plot_data_spatial.columns]
plot_data_wavelength = (
stats.query("description != 'Trace'")
.filter(regex="ccd|median|weighted|soften")
.filter(regex="ccd|wavelength.(median|weighted|soften)")
.groupby("ccd", observed=False)
.mean()
)
plot_data_wavelength.columns = [c.replace("wavelength.", "") for c in plot_data_wavelength.columns]

fig, (ax0, ax1) = plt.subplots(ncols=2, sharey=True, layout="constrained")
fig.set_size_inches(12, 4)
fig = Figure(figsize=(11, 8), layout="constrained")
spatial_plot_ax = fig.add_subplot(221)
wavelength_plot_ax = fig.add_subplot(222, sharey=spatial_plot_ax)
spatial_table_ax = fig.add_subplot(223, sharex=spatial_plot_ax)
wavelength_table_ax = fig.add_subplot(224, sharex=wavelength_plot_ax, sharey=spatial_table_ax)

formatter = "{:5.04f}".format

# Plot the spatial median and weightedRms and show table below.
for ccd, row in plot_data_spatial.iterrows():
ax0.errorbar(
spatial_plot_ax.errorbar(
x=ccd,
y=row["spatial.median"],
yerr=row["spatial.weightedRms"],
markersize=max(row["spatial.softenFit"].mean() * 100, 1),
y=row["median"],
yerr=row["weightedRms"],
markersize=max(row["softenFit"].mean() * 100, 1),
marker="o",
mec="k",
linewidth=2,
capsize=2,
color=detector_palette[ccd[0]],
)

spatial_table = pd.plotting.table(spatial_table_ax, plot_data_spatial.map(formatter), loc="center")
spatial_table.set_fontsize(11)
spatial_table_ax.set_title("Spatial median and weightedRms error (quartz only)")

for ccd, row in plot_data_wavelength.iterrows():
ax1.errorbar(
wavelength_plot_ax.errorbar(
x=ccd,
y=row["wavelength.median"],
yerr=row["wavelength.weightedRms"],
markersize=max(row["wavelength.softenFit"].mean() * 100, 1),
y=row["median"],
yerr=row["weightedRms"],
markersize=max(row["softenFit"].mean() * 100, 1),
marker="o",
markeredgecolor="k",
linewidth=2,
capsize=2,
color=detector_palette[ccd[0]],
)

ax0.axhline(0, c="k", ls="--", alpha=0.3)
ax0.set_title("Spatial median and weightedRms error (quartz only)")
ax0.set_ylabel("Median (pixel)")
wavelength_table = pd.plotting.table(
wavelength_table_ax, plot_data_wavelength.map(formatter), loc="center"
)
wavelength_table.set_fontsize(11)
wavelength_table_ax.set_title("Wavelength median and weightedRms")

# Set the titles and labels
spatial_plot_ax.axhline(0, c="k", ls="--", alpha=0.3)
spatial_plot_ax.set_title("Spatial median and weightedRms error (quartz only)")
spatial_plot_ax.set_ylabel("Median (pixel)")

ax1.axhline(0, c="k", ls="--", alpha=0.3)
ax1.set_title("Wavelength median and weightedRms")
wavelength_plot_ax.axhline(0, c="k", ls="--", alpha=0.3)
wavelength_plot_ax.set_title("Wavelength median and weightedRms")

ax0.grid(True, color="k", linestyle="--", alpha=0.15)
ax1.grid(True, color="k", linestyle="--", alpha=0.15)
spatial_plot_ax.grid(True, color="k", linestyle="--", alpha=0.15)
wavelength_plot_ax.grid(True, color="k", linestyle="--", alpha=0.15)
spatial_table_ax.set_axis_off()
wavelength_table_ax.set_axis_off()

fig.suptitle("DetectorMap Residuals Summary", y=1.05)

return fig


def plot_detector_summary_per_desc(stats: DataFrame) -> Figure:
def plot_detector_summary_per_desc(stats: pd.DataFrame) -> Figure:
plot_data = (
stats.set_index(["ccd", "description"])
.filter(regex="median|weighted|soften")
Expand Down Expand Up @@ -356,7 +351,7 @@ def plot_detector_summary_per_desc(stats: DataFrame) -> Figure:
aspect=2.5,
flierprops={"marker": ".", "ms": 2},
)
fg.fig.suptitle("DetectorMap Residuals by description", y=1)
fg.figure.suptitle("DetectorMap Residuals by description", y=1)
for i, ax in enumerate(fg.figure.axes):
ax.set_ylabel("Median residual (pixel)")
if i == 0:
Expand Down Expand Up @@ -385,12 +380,12 @@ def plot_detector_summary_per_desc(stats: DataFrame) -> Figure:
alpha=0.5,
)

return fg.fig
return fg.figure


def plot_visits(
plotData: pd.DataFrame,
palette: Optional[dict] = None,
palette: Optional[dict | list] = None,
spatialRange: float = 0.1,
wavelengthRange: float = 0.1,
fig: Optional[Figure] = None,
Expand Down Expand Up @@ -419,34 +414,81 @@ def plot_visits(
"""
plotData = plotData.copy()
fig = fig or Figure(layout="constrained")
fig.set_size_inches(11, 8)
ax0 = fig.add_subplot(121)
ax1 = fig.add_subplot(122, sharex=ax0, sharey=ax0)

plotData["visit_idx"] = plotData.visit.rank(method="first")

palette = palette or description_palette

for ax, metric in zip([ax0, ax1], ["spatial", "wavelength"]):
for desc, grp in plotData.groupby("description"):
grp.plot.scatter(
y="visit_idx",
x=f"{metric}.median",
xerr=f"{metric}.weightedRms",
metricData = plotData.copy()
if metric == "spatial":
metricData = metricData.query("description == 'Trace'")
else:
metricData = metricData.query("description != 'Trace'")

for desc, grp in metricData.groupby("description"):
grpPlotData = grp.copy()
ax.errorbar(
y=grpPlotData["visit_idx"],
x=grpPlotData[f"{metric}.median"],
xerr=grpPlotData[f"{metric}.weightedRms"],
marker="o",
color=palette.get(desc, "red") if palette is not None else None,
ms=10,
elinewidth=2,
capsize=4,
mec="w",
ls="",
color=palette.get(desc, "black"),
label=desc,
ax=ax,
zorder=110,
)

ax.grid(alpha=0.2)
ax.axvline(0, c="k", ls="--", alpha=0.5)
# Mark the median and 1-sigma range across the visits.
summary_stats = metricData.filter(regex="median|weighted").median().to_dict()
upper_range = summary_stats[f"{metric}.median"] + summary_stats[f"{metric}.weightedRms"]
lower_range = summary_stats[f"{metric}.median"] - summary_stats[f"{metric}.weightedRms"]

ax.axvline(summary_stats[f"{metric}.median"], c="k", ls="--", zorder=101)
ax.axvline(upper_range, c="g", ls="--", zorder=101)
ax.axvline(lower_range, c="g", ls="--", zorder=101)
ax.set_title(
f"{metric.upper()}: "
f'median={summary_stats[f"{metric}.median"]:5.04f} '
f'rms={summary_stats[f"{metric}.weightedRms"]:5.04f}'
)

ax.grid(which="major", color="k", axis="y", zorder=-100)
ax.axvline(0, c="k", ls="-", alpha=0.75)
ax.set_title(f"{metric}")
ax.set_xlabel("pix")

leg = ax.legend(loc="upper right", shadow=True)
leg.set_zorder(1000)

if spatialRange is not None and metric == "spatial":
ax.set_xlim(-spatialRange, spatialRange)
if wavelengthRange is not None and metric == "wavelength":
ax.set_xlim(-wavelengthRange, wavelengthRange)

visit_label = [f"{row.visit}" for idx, row in plotData.iterrows()]
ax0.set_yticks(plotData.visit_idx, visit_label, fontsize="xx-small")
# Only label a visit the first time it's seen.
labeled_ticks = set()
all_ticks = list()
visit_idx = list()
for idx, row in plotData.reset_index().iterrows():
if row.visit not in labeled_ticks:
all_ticks.append(f"{row.visit}")
labeled_ticks.add(row.visit)
visit_idx.append(idx + 1)

# Create a striped background to offset the visits.
ax0.set_yticks(visit_idx, labeled_ticks, fontsize="xx-small")
for i, (y0, y1) in enumerate(itertools.pairwise(visit_idx)):
ax0.axhspan(y0, y1, color="whitesmoke" if i % 2 == 0 else "ivory", alpha=0.5)
ax1.axhspan(y0, y1, color="whitesmoke" if i % 2 == 0 else "ivory", alpha=0.5)

ax0.set_ylabel("Visit")
ax0.invert_yaxis()

Expand All @@ -455,57 +497,9 @@ def plot_visits(
return fig


def plot_dataframe(stats: DataFrame) -> Figure:
"""Plot the residual data frame.

Parameters
----------
stats : `pandas.DataFrame`
The data frame to plot.

Returns
-------
fig : `Figure`
The figure.
"""
plot_data_spatial = (
stats.query("description == 'Trace'")
.filter(regex="ccd|spatial.(median|weighted|soften)")
.groupby("ccd", observed=False)
.mean()
)
plot_data_spatial.columns = [c.replace("spatial.", "") for c in plot_data_spatial.columns]
plot_data_wavelength = (
stats.query("description != 'Trace'")
.filter(regex="ccd|wavelength.(median|weighted|soften)")
.groupby("ccd", observed=False)
.mean()
)
plot_data_wavelength.columns = [c.replace("wavelength.", "") for c in plot_data_wavelength.columns]

formatter = "{:5.04f}".format

fig = Figure(layout="constrained")
ax0 = fig.add_subplot(211)
ax1 = fig.add_subplot(212)
ax0.set_axis_off()
ax1.set_axis_off()
t0 = pd.plotting.table(ax0, plot_data_spatial.map(formatter), loc="center")
t0.set_fontsize(11)
t1 = pd.plotting.table(ax1, plot_data_wavelength.map(formatter), loc="center")
t1.set_fontsize(11)

ax0.set_title("Spatial (quartz only)", y=1.12)
ax1.set_title("Wavelength", y=1.12)

fig.suptitle("Residuals summary", y=1.15)

return fig


def plot_title(run_name: str) -> Figure:
"""Plot a title page for the combined report."""
fig = Figure()
fig = Figure(figsize=(11, 8))
ax = fig.add_subplot(111)
ax.set_axis_off()
ax.text(0.5, 0.5, "DetectorMap Residuals Summary", ha="center", va="center", fontsize="large")
Expand Down
Loading