From cf4f3d1f4033e8aa46e23d04328614cc68106e06 Mon Sep 17 00:00:00 2001
From: Wilfred Tyler Gee
Date: Tue, 4 Mar 2025 09:07:16 -1000
Subject: [PATCH] Using the `raw.visitInfo.observationReason` to determine if `isScience`.

Helps with plotting the spatial but only for science.

Small change to make `dataId` a required param.
Minor cleanup to some plotting code.
Split the stats up by description as well as by status_type.
Better arc line description color palette, including a default fall-back
Alternate row colors to make visits easier to distinguish.
Make the size of the per-visit plot match the 2d detector residual plot size.
Use only the RESERVED data for the visit stats.
Set a consistent figure size for the combined report.
Combine the summary plot and table.
Remove `plot_dataframe`
Get the median for each of the visit stats rather than simply the first record.
Also only pass the Trace for the spatial.
Update detector summary group visit plot
Fix the fiberId number and lines num count
Fix the soften values and counts.
Mark outliers by `description` but compute stats by `isTrace`
Fix the legend for the combined plots
---
 python/pfs/drp/qa/dmCombinedResiduals.py | 238 +++++++++++------------
 python/pfs/drp/qa/dmResiduals.py         | 179 +++++++++--------
 python/pfs/drp/qa/utils/plotting.py      |  25 ++-
 3 files changed, 222 insertions(+), 220 deletions(-)

diff --git a/python/pfs/drp/qa/dmCombinedResiduals.py b/python/pfs/drp/qa/dmCombinedResiduals.py
index 3c3ad912..e5c3f6b0 100644
--- a/python/pfs/drp/qa/dmCombinedResiduals.py
+++ b/python/pfs/drp/qa/dmCombinedResiduals.py
@@ -18,9 +18,7 @@
     Input as InputConnection,
     Output as OutputConnection,
 )
-from matplotlib import pyplot as plt
 from matplotlib.figure import Figure
-from pandas import DataFrame
 from pfs.drp.stella import DetectorMap
 
 from pfs.drp.qa.dmResiduals import plot_detectormap_residuals
@@ -122,8 +120,8 @@ def runQuantum(
     def run(
         self,
         detectorMaps: Iterable[DetectorMap],
-        dmQaResidualData: Iterable[DataFrame],
-        dmQaResidualStats: Iterable[DataFrame],
+        dmQaResidualData: Iterable[pd.DataFrame],
+        dmQaResidualStats: Iterable[pd.DataFrame],
         run_name: str,
     ) -> Struct:
         """Create detector level residual_stats and plots.
@@ -177,8 +175,8 @@ def get_ccd(dm: DetectorMap) -> str:
 
 
 def make_report(
-    residual_stats: DataFrame,
-    arc_data: DataFrame,
+    residual_stats: pd.DataFrame,
+    residual_data: pd.DataFrame,
     detectorMaps: Dict[str, DetectorMap],
     run_name: str,
     log: object,
@@ -190,9 +188,6 @@
     # Add the title as a figure.
     pdf.append(plot_title(run_name))
 
-    # Add the table data as a figure.
-    pdf.append(plot_dataframe(reserved_stats))
-
     # Detector summaries.
     log.info("Making detector summary plots")
     pdf.append(plot_detector_summary(reserved_stats))
@@ -224,19 +219,19 @@
             # Add the 2D residual plot.
             arm = ccd[0]
             spec = int(ccd[1])
-            data = arc_data.query(f"arm == '{arm}' and spectrograph == {spec}")
+            plot_data = residual_data.query(f"arm == '{arm}' and spectrograph == {spec}")
             # If we are doing a combined report we want to get the mean across visits. 
- grouped = data[plot_cols].groupby(["status", "isLine", "fiberId", "y"]) - data = grouped.mean().reset_index() + grouped = plot_data[plot_cols].groupby(["status", "isLine", "fiberId", "y"]) + plot_data = grouped.mean().reset_index() - residFig = plot_detectormap_residuals(data, visit_stats, detectorMaps[str(ccd)]) + residFig = plot_detectormap_residuals(plot_data, visit_stats, detectorMaps[str(ccd)]) residFig.suptitle(f"DetectorMap Residuals - Median of all visits - {ccd}", weight="bold") pdf.append(residFig, dpi=150) # Add the description per visit breakdown. - fig = plot_detector_visits(visit_stats) - fig.suptitle(f"{fig.get_suptitle()}\n{ccd}") + fig = plot_visits(visit_stats.query('status_type == "RESERVED"'), palette=description_palette) + fig.suptitle(f"{fig.get_suptitle()} - {ccd}") pdf.append(fig) except KeyError: log.warning(f"DetectorMap not found for {ccd}. Skipping.") @@ -247,52 +242,37 @@ def make_report( return pdf -def plot_detector_visits(plot_data: DataFrame) -> Figure: - summary_stats = plot_data.filter(regex="median|weighted").mean().to_dict() - - fig = plot_visits(plot_data, palette=description_palette) - - for ax, dim in zip(fig.axes, ["spatial", "wavelength"]): - upper_range = summary_stats[f"{dim}.median"] + summary_stats[f"{dim}.weightedRms"] - lower_range = summary_stats[f"{dim}.median"] - summary_stats[f"{dim}.weightedRms"] - - ax.axvline(summary_stats[f"{dim}.median"], c="k", ls="--") - ax.axvline(upper_range, c="g", ls="--") - ax.axvline(lower_range, c="g", ls="--") - ax.set_title( - f"{dim.upper()}: " - f'median={summary_stats[f"{dim}.median"]:5.04f} ' - f'rms={summary_stats[f"{dim}.weightedRms"]:5.04f}' - ) - - fig.set_size_inches(8, 8) - - return fig - - -def plot_detector_summary(stats: DataFrame) -> Figure: +def plot_detector_summary(stats: pd.DataFrame) -> Figure: plot_data_spatial = ( stats.query("description == 'Trace'") - .filter(regex="ccd|median|weighted|soften") + .filter(regex="ccd|spatial.(median|weighted|soften)") .groupby("ccd", observed=False) .mean() ) + plot_data_spatial.columns = [c.replace("spatial.", "") for c in plot_data_spatial.columns] plot_data_wavelength = ( stats.query("description != 'Trace'") - .filter(regex="ccd|median|weighted|soften") + .filter(regex="ccd|wavelength.(median|weighted|soften)") .groupby("ccd", observed=False) .mean() ) + plot_data_wavelength.columns = [c.replace("wavelength.", "") for c in plot_data_wavelength.columns] - fig, (ax0, ax1) = plt.subplots(ncols=2, sharey=True, layout="constrained") - fig.set_size_inches(12, 4) + fig = Figure(figsize=(11, 8), layout="constrained") + spatial_plot_ax = fig.add_subplot(221) + wavelength_plot_ax = fig.add_subplot(222, sharey=spatial_plot_ax) + spatial_table_ax = fig.add_subplot(223, sharex=spatial_plot_ax) + wavelength_table_ax = fig.add_subplot(224, sharex=wavelength_plot_ax, sharey=spatial_table_ax) + formatter = "{:5.04f}".format + + # Plot the spatial median and weightedRms and show table below. 
for ccd, row in plot_data_spatial.iterrows(): - ax0.errorbar( + spatial_plot_ax.errorbar( x=ccd, - y=row["spatial.median"], - yerr=row["spatial.weightedRms"], - markersize=max(row["spatial.softenFit"].mean() * 100, 1), + y=row["median"], + yerr=row["weightedRms"], + markersize=max(row["softenFit"].mean() * 100, 1), marker="o", mec="k", linewidth=2, @@ -300,12 +280,16 @@ def plot_detector_summary(stats: DataFrame) -> Figure: color=detector_palette[ccd[0]], ) + spatial_table = pd.plotting.table(spatial_table_ax, plot_data_spatial.map(formatter), loc="center") + spatial_table.set_fontsize(11) + spatial_table_ax.set_title("Spatial median and weightedRms error (quartz only)") + for ccd, row in plot_data_wavelength.iterrows(): - ax1.errorbar( + wavelength_plot_ax.errorbar( x=ccd, - y=row["wavelength.median"], - yerr=row["wavelength.weightedRms"], - markersize=max(row["wavelength.softenFit"].mean() * 100, 1), + y=row["median"], + yerr=row["weightedRms"], + markersize=max(row["softenFit"].mean() * 100, 1), marker="o", markeredgecolor="k", linewidth=2, @@ -313,20 +297,31 @@ def plot_detector_summary(stats: DataFrame) -> Figure: color=detector_palette[ccd[0]], ) - ax0.axhline(0, c="k", ls="--", alpha=0.3) - ax0.set_title("Spatial median and weightedRms error (quartz only)") - ax0.set_ylabel("Median (pixel)") + wavelength_table = pd.plotting.table( + wavelength_table_ax, plot_data_wavelength.map(formatter), loc="center" + ) + wavelength_table.set_fontsize(11) + wavelength_table_ax.set_title("Wavelength median and weightedRms") + + # Set the titles and labels + spatial_plot_ax.axhline(0, c="k", ls="--", alpha=0.3) + spatial_plot_ax.set_title("Spatial median and weightedRms error (quartz only)") + spatial_plot_ax.set_ylabel("Median (pixel)") - ax1.axhline(0, c="k", ls="--", alpha=0.3) - ax1.set_title("Wavelength median and weightedRms") + wavelength_plot_ax.axhline(0, c="k", ls="--", alpha=0.3) + wavelength_plot_ax.set_title("Wavelength median and weightedRms") - ax0.grid(True, color="k", linestyle="--", alpha=0.15) - ax1.grid(True, color="k", linestyle="--", alpha=0.15) + spatial_plot_ax.grid(True, color="k", linestyle="--", alpha=0.15) + wavelength_plot_ax.grid(True, color="k", linestyle="--", alpha=0.15) + spatial_table_ax.set_axis_off() + wavelength_table_ax.set_axis_off() + + fig.suptitle("DetectorMap Residuals Summary", y=1.05) return fig -def plot_detector_summary_per_desc(stats: DataFrame) -> Figure: +def plot_detector_summary_per_desc(stats: pd.DataFrame) -> Figure: plot_data = ( stats.set_index(["ccd", "description"]) .filter(regex="median|weighted|soften") @@ -356,7 +351,7 @@ def plot_detector_summary_per_desc(stats: DataFrame) -> Figure: aspect=2.5, flierprops={"marker": ".", "ms": 2}, ) - fg.fig.suptitle("DetectorMap Residuals by description", y=1) + fg.figure.suptitle("DetectorMap Residuals by description", y=1) for i, ax in enumerate(fg.figure.axes): ax.set_ylabel("Median residual (pixel)") if i == 0: @@ -385,12 +380,12 @@ def plot_detector_summary_per_desc(stats: DataFrame) -> Figure: alpha=0.5, ) - return fg.fig + return fg.figure def plot_visits( plotData: pd.DataFrame, - palette: Optional[dict] = None, + palette: Optional[dict | list] = None, spatialRange: float = 0.1, wavelengthRange: float = 0.1, fig: Optional[Figure] = None, @@ -419,34 +414,81 @@ def plot_visits( """ plotData = plotData.copy() fig = fig or Figure(layout="constrained") + fig.set_size_inches(11, 8) ax0 = fig.add_subplot(121) ax1 = fig.add_subplot(122, sharex=ax0, sharey=ax0) plotData["visit_idx"] = 
plotData.visit.rank(method="first") + palette = palette or description_palette + for ax, metric in zip([ax0, ax1], ["spatial", "wavelength"]): - for desc, grp in plotData.groupby("description"): - grp.plot.scatter( - y="visit_idx", - x=f"{metric}.median", - xerr=f"{metric}.weightedRms", + metricData = plotData.copy() + if metric == "spatial": + metricData = metricData.query("description == 'Trace'") + else: + metricData = metricData.query("description != 'Trace'") + + for desc, grp in metricData.groupby("description"): + grpPlotData = grp.copy() + ax.errorbar( + y=grpPlotData["visit_idx"], + x=grpPlotData[f"{metric}.median"], + xerr=grpPlotData[f"{metric}.weightedRms"], marker="o", - color=palette.get(desc, "red") if palette is not None else None, + ms=10, + elinewidth=2, + capsize=4, + mec="w", + ls="", + color=palette.get(desc, "black"), label=desc, - ax=ax, + zorder=110, ) - ax.grid(alpha=0.2) - ax.axvline(0, c="k", ls="--", alpha=0.5) + # Mark the median and 1-sigma range across the visits. + summary_stats = metricData.filter(regex="median|weighted").median().to_dict() + upper_range = summary_stats[f"{metric}.median"] + summary_stats[f"{metric}.weightedRms"] + lower_range = summary_stats[f"{metric}.median"] - summary_stats[f"{metric}.weightedRms"] + + ax.axvline(summary_stats[f"{metric}.median"], c="k", ls="--", zorder=101) + ax.axvline(upper_range, c="g", ls="--", zorder=101) + ax.axvline(lower_range, c="g", ls="--", zorder=101) + ax.set_title( + f"{metric.upper()}: " + f'median={summary_stats[f"{metric}.median"]:5.04f} ' + f'rms={summary_stats[f"{metric}.weightedRms"]:5.04f}' + ) + + ax.grid(which="major", color="k", axis="y", zorder=-100) + ax.axvline(0, c="k", ls="-", alpha=0.75) ax.set_title(f"{metric}") ax.set_xlabel("pix") + + leg = ax.legend(loc="upper right", shadow=True) + leg.set_zorder(1000) + if spatialRange is not None and metric == "spatial": ax.set_xlim(-spatialRange, spatialRange) if wavelengthRange is not None and metric == "wavelength": ax.set_xlim(-wavelengthRange, wavelengthRange) - visit_label = [f"{row.visit}" for idx, row in plotData.iterrows()] - ax0.set_yticks(plotData.visit_idx, visit_label, fontsize="xx-small") + # Only label a visit the first time it's seen. + labeled_ticks = set() + all_ticks = list() + visit_idx = list() + for idx, row in plotData.reset_index().iterrows(): + if row.visit not in labeled_ticks: + all_ticks.append(f"{row.visit}") + labeled_ticks.add(row.visit) + visit_idx.append(idx + 1) + + # Create a striped background to offset the visits. + ax0.set_yticks(visit_idx, labeled_ticks, fontsize="xx-small") + for i, (y0, y1) in enumerate(itertools.pairwise(visit_idx)): + ax0.axhspan(y0, y1, color="whitesmoke" if i % 2 == 0 else "ivory", alpha=0.5) + ax1.axhspan(y0, y1, color="whitesmoke" if i % 2 == 0 else "ivory", alpha=0.5) + ax0.set_ylabel("Visit") ax0.invert_yaxis() @@ -455,57 +497,9 @@ def plot_visits( return fig -def plot_dataframe(stats: DataFrame) -> Figure: - """Plot the residual data frame. - - Parameters - ---------- - stats : `pandas.DataFrame` - The data frame to plot. - - Returns - ------- - fig : `Figure` - The figure. 
- """ - plot_data_spatial = ( - stats.query("description == 'Trace'") - .filter(regex="ccd|spatial.(median|weighted|soften)") - .groupby("ccd", observed=False) - .mean() - ) - plot_data_spatial.columns = [c.replace("spatial.", "") for c in plot_data_spatial.columns] - plot_data_wavelength = ( - stats.query("description != 'Trace'") - .filter(regex="ccd|wavelength.(median|weighted|soften)") - .groupby("ccd", observed=False) - .mean() - ) - plot_data_wavelength.columns = [c.replace("wavelength.", "") for c in plot_data_wavelength.columns] - - formatter = "{:5.04f}".format - - fig = Figure(layout="constrained") - ax0 = fig.add_subplot(211) - ax1 = fig.add_subplot(212) - ax0.set_axis_off() - ax1.set_axis_off() - t0 = pd.plotting.table(ax0, plot_data_spatial.map(formatter), loc="center") - t0.set_fontsize(11) - t1 = pd.plotting.table(ax1, plot_data_wavelength.map(formatter), loc="center") - t1.set_fontsize(11) - - ax0.set_title("Spatial (quartz only)", y=1.12) - ax1.set_title("Wavelength", y=1.12) - - fig.suptitle("Residuals summary", y=1.15) - - return fig - - def plot_title(run_name: str) -> Figure: """Plot a title page for the combined report.""" - fig = Figure() + fig = Figure(figsize=(11, 8)) ax = fig.add_subplot(111) ax.set_axis_off() ax.text(0.5, 0.5, "DetectorMap Residuals Summary", ha="center", va="center", fontsize="large") diff --git a/python/pfs/drp/qa/dmResiduals.py b/python/pfs/drp/qa/dmResiduals.py index e998271c..a378a847 100644 --- a/python/pfs/drp/qa/dmResiduals.py +++ b/python/pfs/drp/qa/dmResiduals.py @@ -1,5 +1,4 @@ import warnings -from contextlib import suppress from dataclasses import dataclass from functools import partial from logging import Logger @@ -8,6 +7,7 @@ import numpy as np import pandas as pd from astropy.stats import sigma_clip +from lsst.afw.image import VisitInfo from lsst.pex.config import Config, Field from lsst.pipe.base import ( InputQuantizedConnection, @@ -38,37 +38,29 @@ class DetectorMapResidualsConnections( PipelineTaskConnections, - dimensions=( - "instrument", - "visit", - "arm", - "spectrograph", - ), + dimensions=("instrument", "visit", "arm", "spectrograph"), ): """Connections for DetectorMapQaTask""" + visitInfo = InputConnection( + name="raw.visitInfo", + doc="Visit info from the raw exposure", + storageClass="VisitInfo", + dimensions=("instrument", "visit", "arm", "spectrograph"), + ) + detectorMap = InputConnection( name="detectorMap", doc="Adjusted detector mapping from fiberId,wavelength to x,y", storageClass="DetectorMap", - dimensions=( - "instrument", - "visit", - "arm", - "spectrograph", - ), + dimensions=("instrument", "visit", "arm", "spectrograph"), ) arcLines = InputConnection( name="lines", doc="Emission line measurements", storageClass="ArcLineSet", - dimensions=( - "instrument", - "visit", - "arm", - "spectrograph", - ), + dimensions=("instrument", "visit", "arm", "spectrograph"), ) reduceExposure_config = InputConnection( name="reduceExposure_config", @@ -76,38 +68,24 @@ class DetectorMapResidualsConnections( storageClass="Config", dimensions=(), ) + dmQaResidualData = OutputConnection( name="dmQaResidualData", doc="The dataframe of the detectormap residuals.", storageClass="DataFrame", - dimensions=( - "instrument", - "visit", - "arm", - "spectrograph", - ), + dimensions=("instrument", "visit", "arm", "spectrograph"), ) dmQaResidualStats = OutputConnection( name="dmQaResidualStats", doc="Statistics of the DM residual analysis.", storageClass="DataFrame", - dimensions=( - "instrument", - "visit", - "arm", - "spectrograph", - 
), + dimensions=("instrument", "visit", "arm", "spectrograph"), ) dmQaResidualPlot = OutputConnection( name="dmQaResidualPlot", doc="The 1D and 2D residual plots of the detectormap with the arclines for a given visit.", storageClass="Plot", - dimensions=( - "instrument", - "visit", - "arm", - "spectrograph", - ), + dimensions=("instrument", "visit", "arm", "spectrograph"), ) @@ -159,10 +137,11 @@ def run( self, arcLines: ArcLineSet, detectorMap: DetectorMap, + visitInfo: VisitInfo, + dataId: dict, dropNaColumns: bool = True, removeOutliers: bool = True, addFiberInfo: bool = True, - dataId: dict = None, reduceExposure_config: Config = None, **kwargs, ) -> Struct: @@ -178,14 +157,17 @@ def run( The arc lines. detectorMap : `DetectorMap` The detector map. + visitInfo : `VisitInfo` + The visit info containing the observationReason, which determines + some plotting parameters. + dataId : `dict` + The dataId for the visit. dropNaColumns : `bool`, optional Drop columns where all values are NaN. Default is True. removeOutliers : `bool`, optional Remove rows with ``flag=False``? Default is True. addFiberInfo : `bool`, optional Add fiber information to the dataframe. Default is True. - dataId : dict, optional - Dictionary of the dataId. reduceExposure_config : `Config`, optional Configuration for reduceExposure. @@ -202,6 +184,7 @@ def run( dataId, arcLines, detectorMap, + visitInfo, adjustDM_config=adjustDM_config, log=self.log, ) @@ -266,30 +249,47 @@ def to_dict(self): @classmethod def from_dataframe(cls, df: pd.DataFrame): """Convert from dataframe to FitStats.""" - reserved_wl = df.filter(like="wavelength.").copy() - reserved_spatial = df.filter(like="spatial.").copy() - - reserved_wl.columns = reserved_wl.columns.str.rsplit(".", n=1).str[-1] - reserved_spatial.columns = reserved_spatial.columns.str.rsplit(".", n=1).str[-1] - - rec = df.iloc[0] - return cls( - dof=rec.dof, - chi2X=rec.chi2X, - chi2Y=rec.chi2Y, - spatial=FitStat(*reserved_spatial.iloc[0].to_list()), - wavelength=FitStat(*reserved_wl.iloc[0].to_list()), - ) + try: + df = df.select_dtypes(include="number").median().to_frame().T + + reserved_wl = df.filter(like="wavelength.").copy() + reserved_spatial = df.filter(like="spatial.").copy() + + reserved_wl.columns = reserved_wl.columns.str.rsplit(".", n=1).str[-1] + reserved_spatial.columns = reserved_spatial.columns.str.rsplit(".", n=1).str[-1] + + rec = df.iloc[0] + fs = cls( + dof=rec.dof, + chi2X=rec.chi2X, + chi2Y=rec.chi2Y, + spatial=FitStat(*reserved_spatial.iloc[0].to_list()), + wavelength=FitStat(*reserved_wl.iloc[0].to_list()), + ) + except Exception as e: + print(f"Error: {e!r}") + else: + return fs def get_data_and_stats( dataId: dict, arcLines: ArcLineSet, detectorMap: DetectorMap, + visitInfo: VisitInfo, adjustDM_config=None, log=None, ) -> tuple[pd.DataFrame, pd.DataFrame]: - good_lines_idx = getGoodLines(arcLines, detectorMap.getDispersionAtCenter(), adjustDM_config, log) + + is_science = visitInfo.observationReason == "science" + + good_lines_idx = getGoodLines( + arcLines, + dispersion=detectorMap.getDispersionAtCenter(), + isScience=is_science, + adjustDMConfig=adjustDM_config, + log=log, + ) arcLines = arcLines[good_lines_idx].copy() arc_data = scrub_data(arcLines, detectorMap, dropNaColumns=True, log=log) @@ -306,7 +306,7 @@ def maskOutliers(grp): log.info("Masking outliers") with warnings.catch_warnings(): warnings.simplefilter("ignore") - arc_data = arc_data.groupby(["status_type", "isLine"]).apply(maskOutliers) + arc_data = arc_data.groupby(["status_type", 
"description"]).apply(maskOutliers) arc_data.reset_index(drop=True, inplace=True) log.info("Adding fiber information") @@ -322,25 +322,21 @@ def maskOutliers(grp): "(isLine == True and yResidOutlier == False) or (isTrace == True and xResidOutlier == False)" ).copy() - descriptions = sorted(list(arc_data.description.unique())) - with suppress(ValueError): - if len(descriptions) > 1: - descriptions.remove("Trace") - arc_data["arm"] = dataId["arm"] arc_data["spectrograph"] = dataId["spectrograph"] arc_data["visit"] = dataId["visit"] log.info("Getting residual stats") stats = list() - for idx, rows in arc_data.groupby("status_type"): + for (status_type, description), rows in arc_data.groupby(["status_type", "description"]): visit_stats = pd.json_normalize(get_fit_stats(rows).to_dict()) - visit_stats["status_type"] = idx + visit_stats["status_type"] = status_type + visit_stats["description"] = description visit_stats["arm"] = dataId["arm"] visit_stats["spectrograph"] = dataId["spectrograph"] visit_stats["visit"] = dataId["visit"] visit_stats["ccd"] = "{arm}{spectrograph}".format(**dataId) - visit_stats["description"] = ",".join(descriptions) + visit_stats["observationReason"] = visitInfo.observationReason stats.append(visit_stats) stats = pd.concat(stats) @@ -349,7 +345,11 @@ def maskOutliers(grp): def getGoodLines( - lines: ArcLineSet, dispersion: float | None, adjustDMConfig: Config, log: Logger | None = None + lines: ArcLineSet, + dispersion: float | None, + adjustDMConfig: Config, + isScience: bool = False, + log: Logger | None = None, ) -> np.ndarray: """Get the good lines. @@ -362,6 +362,8 @@ def getGoodLines( The dispersion. Default is None. adjustDMConfig : `Config` Configuration used for the detector map adjustment. + isScience : `bool`, optional + Is this a science visit? Default is False. log : `Logger`, optional The logger for the class object. Default is None. @@ -371,14 +373,17 @@ def getGoodLines( The index of the good lines. 
""" log.debug(f"Scrubbing data using config={adjustDMConfig.toDict()}") - isTrace = lines.description == "Trace" - isLine = ~isTrace - numTraceLines = len(set(lines[isTrace].fiberId)) - numArcLines = len(set(lines[isLine].fiberId)) + traceIndex = lines.description == "Trace" + lineIndex = ~traceIndex + numTraceLines = len(set(lines[traceIndex].fiberId)) + numArcLines = len(set(lines[lineIndex].fiberId)) - log.debug(f"{isTrace.sum()} line centroids for {numTraceLines} traces") - log.debug(f"{isLine.sum()} line centroids for {numArcLines} traces") - log.debug(f"{isLine.sum() + isTrace.sum()} lines in list") + isTrace = lineIndex.sum() == 0 + isArc = isTrace == False + + log.debug(f"{traceIndex.sum()} line centroids for {numTraceLines} traces") + log.debug(f"{lineIndex.sum()} line centroids for {numArcLines} traces") + log.debug(f"{lineIndex.sum() + traceIndex.sum()} lines in list") def getCounts(): """Provide a list of counts of different species""" @@ -387,9 +392,9 @@ def getCounts(): good = lines.flag == 0 log.debug(f"{good.sum()} good lines after initial flags ({getCounts()})") - if isLine.sum() > 0: + if not isScience and isArc: log.info("Found lamp species, ignoring traces.") - good &= isLine + good &= lineIndex log.debug(f"{good.sum()} good lines after ignoring traces ({getCounts()})") good &= (lines.status & ReferenceLineStatus.fromNames(*adjustDMConfig.lineFlags)) == 0 @@ -399,7 +404,7 @@ def getCounts(): good &= np.isfinite(lines.xErr) & np.isfinite(lines.yErr) if hasattr(lines, "slope"): - good &= np.isfinite(lines.slope) | ~isTrace + good &= np.isfinite(lines.slope) | ~traceIndex log.debug(f"{good.sum()} good lines after finite positions ({getCounts()})") if adjustDMConfig.minSignalToNoise > 0: @@ -420,15 +425,15 @@ def getCounts(): if adjustDMConfig.maxCentroidError > 0: maxCentroidError = adjustDMConfig.maxCentroidError good &= (lines.xErr > 0) & (lines.xErr < maxCentroidError) - good &= ((lines.yErr > 0) & (lines.yErr < maxCentroidError)) | isTrace + good &= ((lines.yErr > 0) & (lines.yErr < maxCentroidError)) | traceIndex log.debug(f"{good.sum()} good lines after {maxCentroidError=} centroid errors ({getCounts()})") - if dispersion is not None and adjustDMConfig.exclusionRadius > 0 and not np.all(isTrace): - wavelength = np.unique(lines.wavelength[~isTrace]) + if dispersion is not None and adjustDMConfig.exclusionRadius > 0 and not np.all(traceIndex): + wavelength = np.unique(lines.wavelength[~traceIndex]) status = [np.bitwise_or.reduce(lines.status[lines.wavelength == wl]) for wl in wavelength] exclusionRadius = dispersion * adjustDMConfig.exclusionRadius exclude = getExclusionZone(wavelength, exclusionRadius, np.array(status)) - good &= np.isin(lines.wavelength, wavelength[exclude], invert=True) | isTrace + good &= np.isin(lines.wavelength, wavelength[exclude], invert=True) | traceIndex log.debug(f"{good.sum()} good lines after {exclusionRadius=:.03f} exclusion zone ({getCounts()})") return good @@ -567,10 +572,13 @@ def get_fit_stats( lines = arc_data.query("isLine == True").dropna(subset=["yResid"]).copy() xNum = len(arc_data) + numTraces = traces.fiberId.nunique() try: yNum = lines.isLine.value_counts()[True] + numLines = lines.wavelength.nunique() except KeyError: yNum = 0 + numLines = 0 xWeightedRms = getWeightedRMS(arc_data.xResid, arc_data.xErr, soften=xSoften) yWeightedRms = getWeightedRMS(lines.yResid, lines.yErr, soften=ySoften) @@ -617,8 +625,8 @@ def getSoften(resid, err, dof, soften=0): xFibers = len(traces.fiberId.unique()) yFibers = len(lines.fiberId.unique()) 
- xFitStat = FitStat(arc_data.xResid.median(), xRobustRms, xWeightedRms, xSoftFit, xDof, xFibers, xNum) - yFitStat = FitStat(lines.yResid.median(), yRobustRms, yWeightedRms, ySoftFit, yDof, yFibers, yNum) + xFitStat = FitStat(arc_data.xResid.median(), xRobustRms, xWeightedRms, xSoftFit, xDof, xFibers, numTraces) + yFitStat = FitStat(lines.yResid.median(), yRobustRms, yWeightedRms, ySoftFit, yDof, yFibers, numLines) return FitStats(dof, chi2X, chi2Y, xFitStat, yFitStat) @@ -672,10 +680,15 @@ def plot_detectormap_residuals( try: for sub_fig, column in zip([x_fig, y_fig], ["xResid", "yResid"]): + if column == "xResid": + plot_stats = visit_stats.query("description == 'Trace'") + else: + plot_stats = visit_stats.query("description != 'Trace'") + try: plot_residual( arc_data, - visit_stats, + plot_stats, column=column, dataRange=spatialRange if column == "xResid" else wavelengthRange, binWavelength=binWavelength, @@ -790,10 +803,8 @@ def plot_residual( plotData.rename(columns={f"{column}Outlier": "isOutlier"}, inplace=True) units = "pix" - which_data = "spatial" if column.startswith("y"): plotData = plotData.query("isTrace == False").copy() - which_data = "wavelength" else: plotData = plotData.query("isTrace == True").copy() diff --git a/python/pfs/drp/qa/utils/plotting.py b/python/pfs/drp/qa/utils/plotting.py index 50c6de10..5bedacfb 100644 --- a/python/pfs/drp/qa/utils/plotting.py +++ b/python/pfs/drp/qa/utils/plotting.py @@ -8,20 +8,17 @@ div_palette = plt.cm.RdBu_r.with_extremes(over="magenta", under="cyan", bad="lime") detector_palette = {"b": "tab:blue", "r": "tab:red", "n": "tab:orange", "m": "tab:pink"} description_palette = { - "Trace": "#ED0A3F", - "ArI": "tab:orange", - "CdI,HgI": "tab:purple", - "HgI": "tab:purple", - "KrI": "tab:brown", - "NeI": "tab:pink", - "XeI": "tab:olive", - "O2,OH": "tab:blue", - "O2,OH,OI": "tab:blue", - "OH,OI": "tab:blue", - "OH": "tab:blue", - "OI": "tab:blue", - "NaI,OI": "tab:blue", - "NaI,OH,OI": "tab:blue", + "Trace": "black", + "ArI": "tab:blue", + "CdI": "tab:orange", + "HgI": "tab:green", + "KrI": "tab:red", + "NeI": "tab:purple", + "XeI": "tab:brown", + "O2": "tab:pink", + "OH": "tab:gray", + "OI": "tab:olive", + "NaI": "tab:cyan", } spectrograph_plot_markers = {1: "s", 2: "o", 3: "X", 4: "P"}