diff --git a/python/pfs/drp/qa/dmCombinedResiduals.py b/python/pfs/drp/qa/dmCombinedResiduals.py index 3c3ad912..e5c3f6b0 100644 --- a/python/pfs/drp/qa/dmCombinedResiduals.py +++ b/python/pfs/drp/qa/dmCombinedResiduals.py @@ -18,9 +18,7 @@ Input as InputConnection, Output as OutputConnection, ) -from matplotlib import pyplot as plt from matplotlib.figure import Figure -from pandas import DataFrame from pfs.drp.stella import DetectorMap from pfs.drp.qa.dmResiduals import plot_detectormap_residuals @@ -122,8 +120,8 @@ def runQuantum( def run( self, detectorMaps: Iterable[DetectorMap], - dmQaResidualData: Iterable[DataFrame], - dmQaResidualStats: Iterable[DataFrame], + dmQaResidualData: Iterable[pd.DataFrame], + dmQaResidualStats: Iterable[pd.DataFrame], run_name: str, ) -> Struct: """Create detector level residual_stats and plots. @@ -177,8 +175,8 @@ def get_ccd(dm: DetectorMap) -> str: def make_report( - residual_stats: DataFrame, - arc_data: DataFrame, + residual_stats: pd.DataFrame, + residual_data: pd.DataFrame, detectorMaps: Dict[str, DetectorMap], run_name: str, log: object, @@ -190,9 +188,6 @@ def make_report( # Add the title as a figure. pdf.append(plot_title(run_name)) - # Add the table data as a figure. - pdf.append(plot_dataframe(reserved_stats)) - # Detector summaries. log.info("Making detector summary plots") pdf.append(plot_detector_summary(reserved_stats)) @@ -224,19 +219,19 @@ def make_report( # Add the 2D residual plot. arm = ccd[0] spec = int(ccd[1]) - data = arc_data.query(f"arm == '{arm}' and spectrograph == {spec}") + plot_data = residual_data.query(f"arm == '{arm}' and spectrograph == {spec}") # If we are doing a combined report we want to get the mean across visits. - grouped = data[plot_cols].groupby(["status", "isLine", "fiberId", "y"]) - data = grouped.mean().reset_index() + grouped = plot_data[plot_cols].groupby(["status", "isLine", "fiberId", "y"]) + plot_data = grouped.mean().reset_index() - residFig = plot_detectormap_residuals(data, visit_stats, detectorMaps[str(ccd)]) + residFig = plot_detectormap_residuals(plot_data, visit_stats, detectorMaps[str(ccd)]) residFig.suptitle(f"DetectorMap Residuals - Median of all visits - {ccd}", weight="bold") pdf.append(residFig, dpi=150) # Add the description per visit breakdown. - fig = plot_detector_visits(visit_stats) - fig.suptitle(f"{fig.get_suptitle()}\n{ccd}") + fig = plot_visits(visit_stats.query('status_type == "RESERVED"'), palette=description_palette) + fig.suptitle(f"{fig.get_suptitle()} - {ccd}") pdf.append(fig) except KeyError: log.warning(f"DetectorMap not found for {ccd}. Skipping.") @@ -247,52 +242,37 @@ def make_report( return pdf -def plot_detector_visits(plot_data: DataFrame) -> Figure: - summary_stats = plot_data.filter(regex="median|weighted").mean().to_dict() - - fig = plot_visits(plot_data, palette=description_palette) - - for ax, dim in zip(fig.axes, ["spatial", "wavelength"]): - upper_range = summary_stats[f"{dim}.median"] + summary_stats[f"{dim}.weightedRms"] - lower_range = summary_stats[f"{dim}.median"] - summary_stats[f"{dim}.weightedRms"] - - ax.axvline(summary_stats[f"{dim}.median"], c="k", ls="--") - ax.axvline(upper_range, c="g", ls="--") - ax.axvline(lower_range, c="g", ls="--") - ax.set_title( - f"{dim.upper()}: " - f'median={summary_stats[f"{dim}.median"]:5.04f} ' - f'rms={summary_stats[f"{dim}.weightedRms"]:5.04f}' - ) - - fig.set_size_inches(8, 8) - - return fig - - -def plot_detector_summary(stats: DataFrame) -> Figure: +def plot_detector_summary(stats: pd.DataFrame) -> Figure: plot_data_spatial = ( stats.query("description == 'Trace'") - .filter(regex="ccd|median|weighted|soften") + .filter(regex="ccd|spatial.(median|weighted|soften)") .groupby("ccd", observed=False) .mean() ) + plot_data_spatial.columns = [c.replace("spatial.", "") for c in plot_data_spatial.columns] plot_data_wavelength = ( stats.query("description != 'Trace'") - .filter(regex="ccd|median|weighted|soften") + .filter(regex="ccd|wavelength.(median|weighted|soften)") .groupby("ccd", observed=False) .mean() ) + plot_data_wavelength.columns = [c.replace("wavelength.", "") for c in plot_data_wavelength.columns] - fig, (ax0, ax1) = plt.subplots(ncols=2, sharey=True, layout="constrained") - fig.set_size_inches(12, 4) + fig = Figure(figsize=(11, 8), layout="constrained") + spatial_plot_ax = fig.add_subplot(221) + wavelength_plot_ax = fig.add_subplot(222, sharey=spatial_plot_ax) + spatial_table_ax = fig.add_subplot(223, sharex=spatial_plot_ax) + wavelength_table_ax = fig.add_subplot(224, sharex=wavelength_plot_ax, sharey=spatial_table_ax) + formatter = "{:5.04f}".format + + # Plot the spatial median and weightedRms and show table below. for ccd, row in plot_data_spatial.iterrows(): - ax0.errorbar( + spatial_plot_ax.errorbar( x=ccd, - y=row["spatial.median"], - yerr=row["spatial.weightedRms"], - markersize=max(row["spatial.softenFit"].mean() * 100, 1), + y=row["median"], + yerr=row["weightedRms"], + markersize=max(row["softenFit"].mean() * 100, 1), marker="o", mec="k", linewidth=2, @@ -300,12 +280,16 @@ def plot_detector_summary(stats: DataFrame) -> Figure: color=detector_palette[ccd[0]], ) + spatial_table = pd.plotting.table(spatial_table_ax, plot_data_spatial.map(formatter), loc="center") + spatial_table.set_fontsize(11) + spatial_table_ax.set_title("Spatial median and weightedRms error (quartz only)") + for ccd, row in plot_data_wavelength.iterrows(): - ax1.errorbar( + wavelength_plot_ax.errorbar( x=ccd, - y=row["wavelength.median"], - yerr=row["wavelength.weightedRms"], - markersize=max(row["wavelength.softenFit"].mean() * 100, 1), + y=row["median"], + yerr=row["weightedRms"], + markersize=max(row["softenFit"].mean() * 100, 1), marker="o", markeredgecolor="k", linewidth=2, @@ -313,20 +297,31 @@ def plot_detector_summary(stats: DataFrame) -> Figure: color=detector_palette[ccd[0]], ) - ax0.axhline(0, c="k", ls="--", alpha=0.3) - ax0.set_title("Spatial median and weightedRms error (quartz only)") - ax0.set_ylabel("Median (pixel)") + wavelength_table = pd.plotting.table( + wavelength_table_ax, plot_data_wavelength.map(formatter), loc="center" + ) + wavelength_table.set_fontsize(11) + wavelength_table_ax.set_title("Wavelength median and weightedRms") + + # Set the titles and labels + spatial_plot_ax.axhline(0, c="k", ls="--", alpha=0.3) + spatial_plot_ax.set_title("Spatial median and weightedRms error (quartz only)") + spatial_plot_ax.set_ylabel("Median (pixel)") - ax1.axhline(0, c="k", ls="--", alpha=0.3) - ax1.set_title("Wavelength median and weightedRms") + wavelength_plot_ax.axhline(0, c="k", ls="--", alpha=0.3) + wavelength_plot_ax.set_title("Wavelength median and weightedRms") - ax0.grid(True, color="k", linestyle="--", alpha=0.15) - ax1.grid(True, color="k", linestyle="--", alpha=0.15) + spatial_plot_ax.grid(True, color="k", linestyle="--", alpha=0.15) + wavelength_plot_ax.grid(True, color="k", linestyle="--", alpha=0.15) + spatial_table_ax.set_axis_off() + wavelength_table_ax.set_axis_off() + + fig.suptitle("DetectorMap Residuals Summary", y=1.05) return fig -def plot_detector_summary_per_desc(stats: DataFrame) -> Figure: +def plot_detector_summary_per_desc(stats: pd.DataFrame) -> Figure: plot_data = ( stats.set_index(["ccd", "description"]) .filter(regex="median|weighted|soften") @@ -356,7 +351,7 @@ def plot_detector_summary_per_desc(stats: DataFrame) -> Figure: aspect=2.5, flierprops={"marker": ".", "ms": 2}, ) - fg.fig.suptitle("DetectorMap Residuals by description", y=1) + fg.figure.suptitle("DetectorMap Residuals by description", y=1) for i, ax in enumerate(fg.figure.axes): ax.set_ylabel("Median residual (pixel)") if i == 0: @@ -385,12 +380,12 @@ def plot_detector_summary_per_desc(stats: DataFrame) -> Figure: alpha=0.5, ) - return fg.fig + return fg.figure def plot_visits( plotData: pd.DataFrame, - palette: Optional[dict] = None, + palette: Optional[dict | list] = None, spatialRange: float = 0.1, wavelengthRange: float = 0.1, fig: Optional[Figure] = None, @@ -419,34 +414,81 @@ def plot_visits( """ plotData = plotData.copy() fig = fig or Figure(layout="constrained") + fig.set_size_inches(11, 8) ax0 = fig.add_subplot(121) ax1 = fig.add_subplot(122, sharex=ax0, sharey=ax0) plotData["visit_idx"] = plotData.visit.rank(method="first") + palette = palette or description_palette + for ax, metric in zip([ax0, ax1], ["spatial", "wavelength"]): - for desc, grp in plotData.groupby("description"): - grp.plot.scatter( - y="visit_idx", - x=f"{metric}.median", - xerr=f"{metric}.weightedRms", + metricData = plotData.copy() + if metric == "spatial": + metricData = metricData.query("description == 'Trace'") + else: + metricData = metricData.query("description != 'Trace'") + + for desc, grp in metricData.groupby("description"): + grpPlotData = grp.copy() + ax.errorbar( + y=grpPlotData["visit_idx"], + x=grpPlotData[f"{metric}.median"], + xerr=grpPlotData[f"{metric}.weightedRms"], marker="o", - color=palette.get(desc, "red") if palette is not None else None, + ms=10, + elinewidth=2, + capsize=4, + mec="w", + ls="", + color=palette.get(desc, "black"), label=desc, - ax=ax, + zorder=110, ) - ax.grid(alpha=0.2) - ax.axvline(0, c="k", ls="--", alpha=0.5) + # Mark the median and 1-sigma range across the visits. + summary_stats = metricData.filter(regex="median|weighted").median().to_dict() + upper_range = summary_stats[f"{metric}.median"] + summary_stats[f"{metric}.weightedRms"] + lower_range = summary_stats[f"{metric}.median"] - summary_stats[f"{metric}.weightedRms"] + + ax.axvline(summary_stats[f"{metric}.median"], c="k", ls="--", zorder=101) + ax.axvline(upper_range, c="g", ls="--", zorder=101) + ax.axvline(lower_range, c="g", ls="--", zorder=101) + ax.set_title( + f"{metric.upper()}: " + f'median={summary_stats[f"{metric}.median"]:5.04f} ' + f'rms={summary_stats[f"{metric}.weightedRms"]:5.04f}' + ) + + ax.grid(which="major", color="k", axis="y", zorder=-100) + ax.axvline(0, c="k", ls="-", alpha=0.75) ax.set_title(f"{metric}") ax.set_xlabel("pix") + + leg = ax.legend(loc="upper right", shadow=True) + leg.set_zorder(1000) + if spatialRange is not None and metric == "spatial": ax.set_xlim(-spatialRange, spatialRange) if wavelengthRange is not None and metric == "wavelength": ax.set_xlim(-wavelengthRange, wavelengthRange) - visit_label = [f"{row.visit}" for idx, row in plotData.iterrows()] - ax0.set_yticks(plotData.visit_idx, visit_label, fontsize="xx-small") + # Only label a visit the first time it's seen. + labeled_ticks = set() + all_ticks = list() + visit_idx = list() + for idx, row in plotData.reset_index().iterrows(): + if row.visit not in labeled_ticks: + all_ticks.append(f"{row.visit}") + labeled_ticks.add(row.visit) + visit_idx.append(idx + 1) + + # Create a striped background to offset the visits. + ax0.set_yticks(visit_idx, labeled_ticks, fontsize="xx-small") + for i, (y0, y1) in enumerate(itertools.pairwise(visit_idx)): + ax0.axhspan(y0, y1, color="whitesmoke" if i % 2 == 0 else "ivory", alpha=0.5) + ax1.axhspan(y0, y1, color="whitesmoke" if i % 2 == 0 else "ivory", alpha=0.5) + ax0.set_ylabel("Visit") ax0.invert_yaxis() @@ -455,57 +497,9 @@ def plot_visits( return fig -def plot_dataframe(stats: DataFrame) -> Figure: - """Plot the residual data frame. - - Parameters - ---------- - stats : `pandas.DataFrame` - The data frame to plot. - - Returns - ------- - fig : `Figure` - The figure. - """ - plot_data_spatial = ( - stats.query("description == 'Trace'") - .filter(regex="ccd|spatial.(median|weighted|soften)") - .groupby("ccd", observed=False) - .mean() - ) - plot_data_spatial.columns = [c.replace("spatial.", "") for c in plot_data_spatial.columns] - plot_data_wavelength = ( - stats.query("description != 'Trace'") - .filter(regex="ccd|wavelength.(median|weighted|soften)") - .groupby("ccd", observed=False) - .mean() - ) - plot_data_wavelength.columns = [c.replace("wavelength.", "") for c in plot_data_wavelength.columns] - - formatter = "{:5.04f}".format - - fig = Figure(layout="constrained") - ax0 = fig.add_subplot(211) - ax1 = fig.add_subplot(212) - ax0.set_axis_off() - ax1.set_axis_off() - t0 = pd.plotting.table(ax0, plot_data_spatial.map(formatter), loc="center") - t0.set_fontsize(11) - t1 = pd.plotting.table(ax1, plot_data_wavelength.map(formatter), loc="center") - t1.set_fontsize(11) - - ax0.set_title("Spatial (quartz only)", y=1.12) - ax1.set_title("Wavelength", y=1.12) - - fig.suptitle("Residuals summary", y=1.15) - - return fig - - def plot_title(run_name: str) -> Figure: """Plot a title page for the combined report.""" - fig = Figure() + fig = Figure(figsize=(11, 8)) ax = fig.add_subplot(111) ax.set_axis_off() ax.text(0.5, 0.5, "DetectorMap Residuals Summary", ha="center", va="center", fontsize="large") diff --git a/python/pfs/drp/qa/dmResiduals.py b/python/pfs/drp/qa/dmResiduals.py index e998271c..a378a847 100644 --- a/python/pfs/drp/qa/dmResiduals.py +++ b/python/pfs/drp/qa/dmResiduals.py @@ -1,5 +1,4 @@ import warnings -from contextlib import suppress from dataclasses import dataclass from functools import partial from logging import Logger @@ -8,6 +7,7 @@ import numpy as np import pandas as pd from astropy.stats import sigma_clip +from lsst.afw.image import VisitInfo from lsst.pex.config import Config, Field from lsst.pipe.base import ( InputQuantizedConnection, @@ -38,37 +38,29 @@ class DetectorMapResidualsConnections( PipelineTaskConnections, - dimensions=( - "instrument", - "visit", - "arm", - "spectrograph", - ), + dimensions=("instrument", "visit", "arm", "spectrograph"), ): """Connections for DetectorMapQaTask""" + visitInfo = InputConnection( + name="raw.visitInfo", + doc="Visit info from the raw exposure", + storageClass="VisitInfo", + dimensions=("instrument", "visit", "arm", "spectrograph"), + ) + detectorMap = InputConnection( name="detectorMap", doc="Adjusted detector mapping from fiberId,wavelength to x,y", storageClass="DetectorMap", - dimensions=( - "instrument", - "visit", - "arm", - "spectrograph", - ), + dimensions=("instrument", "visit", "arm", "spectrograph"), ) arcLines = InputConnection( name="lines", doc="Emission line measurements", storageClass="ArcLineSet", - dimensions=( - "instrument", - "visit", - "arm", - "spectrograph", - ), + dimensions=("instrument", "visit", "arm", "spectrograph"), ) reduceExposure_config = InputConnection( name="reduceExposure_config", @@ -76,38 +68,24 @@ class DetectorMapResidualsConnections( storageClass="Config", dimensions=(), ) + dmQaResidualData = OutputConnection( name="dmQaResidualData", doc="The dataframe of the detectormap residuals.", storageClass="DataFrame", - dimensions=( - "instrument", - "visit", - "arm", - "spectrograph", - ), + dimensions=("instrument", "visit", "arm", "spectrograph"), ) dmQaResidualStats = OutputConnection( name="dmQaResidualStats", doc="Statistics of the DM residual analysis.", storageClass="DataFrame", - dimensions=( - "instrument", - "visit", - "arm", - "spectrograph", - ), + dimensions=("instrument", "visit", "arm", "spectrograph"), ) dmQaResidualPlot = OutputConnection( name="dmQaResidualPlot", doc="The 1D and 2D residual plots of the detectormap with the arclines for a given visit.", storageClass="Plot", - dimensions=( - "instrument", - "visit", - "arm", - "spectrograph", - ), + dimensions=("instrument", "visit", "arm", "spectrograph"), ) @@ -159,10 +137,11 @@ def run( self, arcLines: ArcLineSet, detectorMap: DetectorMap, + visitInfo: VisitInfo, + dataId: dict, dropNaColumns: bool = True, removeOutliers: bool = True, addFiberInfo: bool = True, - dataId: dict = None, reduceExposure_config: Config = None, **kwargs, ) -> Struct: @@ -178,14 +157,17 @@ def run( The arc lines. detectorMap : `DetectorMap` The detector map. + visitInfo : `VisitInfo` + The visit info containing the observationReason, which determines + some plotting parameters. + dataId : `dict` + The dataId for the visit. dropNaColumns : `bool`, optional Drop columns where all values are NaN. Default is True. removeOutliers : `bool`, optional Remove rows with ``flag=False``? Default is True. addFiberInfo : `bool`, optional Add fiber information to the dataframe. Default is True. - dataId : dict, optional - Dictionary of the dataId. reduceExposure_config : `Config`, optional Configuration for reduceExposure. @@ -202,6 +184,7 @@ def run( dataId, arcLines, detectorMap, + visitInfo, adjustDM_config=adjustDM_config, log=self.log, ) @@ -266,30 +249,47 @@ def to_dict(self): @classmethod def from_dataframe(cls, df: pd.DataFrame): """Convert from dataframe to FitStats.""" - reserved_wl = df.filter(like="wavelength.").copy() - reserved_spatial = df.filter(like="spatial.").copy() - - reserved_wl.columns = reserved_wl.columns.str.rsplit(".", n=1).str[-1] - reserved_spatial.columns = reserved_spatial.columns.str.rsplit(".", n=1).str[-1] - - rec = df.iloc[0] - return cls( - dof=rec.dof, - chi2X=rec.chi2X, - chi2Y=rec.chi2Y, - spatial=FitStat(*reserved_spatial.iloc[0].to_list()), - wavelength=FitStat(*reserved_wl.iloc[0].to_list()), - ) + try: + df = df.select_dtypes(include="number").median().to_frame().T + + reserved_wl = df.filter(like="wavelength.").copy() + reserved_spatial = df.filter(like="spatial.").copy() + + reserved_wl.columns = reserved_wl.columns.str.rsplit(".", n=1).str[-1] + reserved_spatial.columns = reserved_spatial.columns.str.rsplit(".", n=1).str[-1] + + rec = df.iloc[0] + fs = cls( + dof=rec.dof, + chi2X=rec.chi2X, + chi2Y=rec.chi2Y, + spatial=FitStat(*reserved_spatial.iloc[0].to_list()), + wavelength=FitStat(*reserved_wl.iloc[0].to_list()), + ) + except Exception as e: + print(f"Error: {e!r}") + else: + return fs def get_data_and_stats( dataId: dict, arcLines: ArcLineSet, detectorMap: DetectorMap, + visitInfo: VisitInfo, adjustDM_config=None, log=None, ) -> tuple[pd.DataFrame, pd.DataFrame]: - good_lines_idx = getGoodLines(arcLines, detectorMap.getDispersionAtCenter(), adjustDM_config, log) + + is_science = visitInfo.observationReason == "science" + + good_lines_idx = getGoodLines( + arcLines, + dispersion=detectorMap.getDispersionAtCenter(), + isScience=is_science, + adjustDMConfig=adjustDM_config, + log=log, + ) arcLines = arcLines[good_lines_idx].copy() arc_data = scrub_data(arcLines, detectorMap, dropNaColumns=True, log=log) @@ -306,7 +306,7 @@ def maskOutliers(grp): log.info("Masking outliers") with warnings.catch_warnings(): warnings.simplefilter("ignore") - arc_data = arc_data.groupby(["status_type", "isLine"]).apply(maskOutliers) + arc_data = arc_data.groupby(["status_type", "description"]).apply(maskOutliers) arc_data.reset_index(drop=True, inplace=True) log.info("Adding fiber information") @@ -322,25 +322,21 @@ def maskOutliers(grp): "(isLine == True and yResidOutlier == False) or (isTrace == True and xResidOutlier == False)" ).copy() - descriptions = sorted(list(arc_data.description.unique())) - with suppress(ValueError): - if len(descriptions) > 1: - descriptions.remove("Trace") - arc_data["arm"] = dataId["arm"] arc_data["spectrograph"] = dataId["spectrograph"] arc_data["visit"] = dataId["visit"] log.info("Getting residual stats") stats = list() - for idx, rows in arc_data.groupby("status_type"): + for (status_type, description), rows in arc_data.groupby(["status_type", "description"]): visit_stats = pd.json_normalize(get_fit_stats(rows).to_dict()) - visit_stats["status_type"] = idx + visit_stats["status_type"] = status_type + visit_stats["description"] = description visit_stats["arm"] = dataId["arm"] visit_stats["spectrograph"] = dataId["spectrograph"] visit_stats["visit"] = dataId["visit"] visit_stats["ccd"] = "{arm}{spectrograph}".format(**dataId) - visit_stats["description"] = ",".join(descriptions) + visit_stats["observationReason"] = visitInfo.observationReason stats.append(visit_stats) stats = pd.concat(stats) @@ -349,7 +345,11 @@ def maskOutliers(grp): def getGoodLines( - lines: ArcLineSet, dispersion: float | None, adjustDMConfig: Config, log: Logger | None = None + lines: ArcLineSet, + dispersion: float | None, + adjustDMConfig: Config, + isScience: bool = False, + log: Logger | None = None, ) -> np.ndarray: """Get the good lines. @@ -362,6 +362,8 @@ def getGoodLines( The dispersion. Default is None. adjustDMConfig : `Config` Configuration used for the detector map adjustment. + isScience : `bool`, optional + Is this a science visit? Default is False. log : `Logger`, optional The logger for the class object. Default is None. @@ -371,14 +373,17 @@ def getGoodLines( The index of the good lines. """ log.debug(f"Scrubbing data using config={adjustDMConfig.toDict()}") - isTrace = lines.description == "Trace" - isLine = ~isTrace - numTraceLines = len(set(lines[isTrace].fiberId)) - numArcLines = len(set(lines[isLine].fiberId)) + traceIndex = lines.description == "Trace" + lineIndex = ~traceIndex + numTraceLines = len(set(lines[traceIndex].fiberId)) + numArcLines = len(set(lines[lineIndex].fiberId)) - log.debug(f"{isTrace.sum()} line centroids for {numTraceLines} traces") - log.debug(f"{isLine.sum()} line centroids for {numArcLines} traces") - log.debug(f"{isLine.sum() + isTrace.sum()} lines in list") + isTrace = lineIndex.sum() == 0 + isArc = isTrace == False + + log.debug(f"{traceIndex.sum()} line centroids for {numTraceLines} traces") + log.debug(f"{lineIndex.sum()} line centroids for {numArcLines} traces") + log.debug(f"{lineIndex.sum() + traceIndex.sum()} lines in list") def getCounts(): """Provide a list of counts of different species""" @@ -387,9 +392,9 @@ def getCounts(): good = lines.flag == 0 log.debug(f"{good.sum()} good lines after initial flags ({getCounts()})") - if isLine.sum() > 0: + if not isScience and isArc: log.info("Found lamp species, ignoring traces.") - good &= isLine + good &= lineIndex log.debug(f"{good.sum()} good lines after ignoring traces ({getCounts()})") good &= (lines.status & ReferenceLineStatus.fromNames(*adjustDMConfig.lineFlags)) == 0 @@ -399,7 +404,7 @@ def getCounts(): good &= np.isfinite(lines.xErr) & np.isfinite(lines.yErr) if hasattr(lines, "slope"): - good &= np.isfinite(lines.slope) | ~isTrace + good &= np.isfinite(lines.slope) | ~traceIndex log.debug(f"{good.sum()} good lines after finite positions ({getCounts()})") if adjustDMConfig.minSignalToNoise > 0: @@ -420,15 +425,15 @@ def getCounts(): if adjustDMConfig.maxCentroidError > 0: maxCentroidError = adjustDMConfig.maxCentroidError good &= (lines.xErr > 0) & (lines.xErr < maxCentroidError) - good &= ((lines.yErr > 0) & (lines.yErr < maxCentroidError)) | isTrace + good &= ((lines.yErr > 0) & (lines.yErr < maxCentroidError)) | traceIndex log.debug(f"{good.sum()} good lines after {maxCentroidError=} centroid errors ({getCounts()})") - if dispersion is not None and adjustDMConfig.exclusionRadius > 0 and not np.all(isTrace): - wavelength = np.unique(lines.wavelength[~isTrace]) + if dispersion is not None and adjustDMConfig.exclusionRadius > 0 and not np.all(traceIndex): + wavelength = np.unique(lines.wavelength[~traceIndex]) status = [np.bitwise_or.reduce(lines.status[lines.wavelength == wl]) for wl in wavelength] exclusionRadius = dispersion * adjustDMConfig.exclusionRadius exclude = getExclusionZone(wavelength, exclusionRadius, np.array(status)) - good &= np.isin(lines.wavelength, wavelength[exclude], invert=True) | isTrace + good &= np.isin(lines.wavelength, wavelength[exclude], invert=True) | traceIndex log.debug(f"{good.sum()} good lines after {exclusionRadius=:.03f} exclusion zone ({getCounts()})") return good @@ -567,10 +572,13 @@ def get_fit_stats( lines = arc_data.query("isLine == True").dropna(subset=["yResid"]).copy() xNum = len(arc_data) + numTraces = traces.fiberId.nunique() try: yNum = lines.isLine.value_counts()[True] + numLines = lines.wavelength.nunique() except KeyError: yNum = 0 + numLines = 0 xWeightedRms = getWeightedRMS(arc_data.xResid, arc_data.xErr, soften=xSoften) yWeightedRms = getWeightedRMS(lines.yResid, lines.yErr, soften=ySoften) @@ -617,8 +625,8 @@ def getSoften(resid, err, dof, soften=0): xFibers = len(traces.fiberId.unique()) yFibers = len(lines.fiberId.unique()) - xFitStat = FitStat(arc_data.xResid.median(), xRobustRms, xWeightedRms, xSoftFit, xDof, xFibers, xNum) - yFitStat = FitStat(lines.yResid.median(), yRobustRms, yWeightedRms, ySoftFit, yDof, yFibers, yNum) + xFitStat = FitStat(arc_data.xResid.median(), xRobustRms, xWeightedRms, xSoftFit, xDof, xFibers, numTraces) + yFitStat = FitStat(lines.yResid.median(), yRobustRms, yWeightedRms, ySoftFit, yDof, yFibers, numLines) return FitStats(dof, chi2X, chi2Y, xFitStat, yFitStat) @@ -672,10 +680,15 @@ def plot_detectormap_residuals( try: for sub_fig, column in zip([x_fig, y_fig], ["xResid", "yResid"]): + if column == "xResid": + plot_stats = visit_stats.query("description == 'Trace'") + else: + plot_stats = visit_stats.query("description != 'Trace'") + try: plot_residual( arc_data, - visit_stats, + plot_stats, column=column, dataRange=spatialRange if column == "xResid" else wavelengthRange, binWavelength=binWavelength, @@ -790,10 +803,8 @@ def plot_residual( plotData.rename(columns={f"{column}Outlier": "isOutlier"}, inplace=True) units = "pix" - which_data = "spatial" if column.startswith("y"): plotData = plotData.query("isTrace == False").copy() - which_data = "wavelength" else: plotData = plotData.query("isTrace == True").copy() diff --git a/python/pfs/drp/qa/utils/plotting.py b/python/pfs/drp/qa/utils/plotting.py index 50c6de10..5bedacfb 100644 --- a/python/pfs/drp/qa/utils/plotting.py +++ b/python/pfs/drp/qa/utils/plotting.py @@ -8,20 +8,17 @@ div_palette = plt.cm.RdBu_r.with_extremes(over="magenta", under="cyan", bad="lime") detector_palette = {"b": "tab:blue", "r": "tab:red", "n": "tab:orange", "m": "tab:pink"} description_palette = { - "Trace": "#ED0A3F", - "ArI": "tab:orange", - "CdI,HgI": "tab:purple", - "HgI": "tab:purple", - "KrI": "tab:brown", - "NeI": "tab:pink", - "XeI": "tab:olive", - "O2,OH": "tab:blue", - "O2,OH,OI": "tab:blue", - "OH,OI": "tab:blue", - "OH": "tab:blue", - "OI": "tab:blue", - "NaI,OI": "tab:blue", - "NaI,OH,OI": "tab:blue", + "Trace": "black", + "ArI": "tab:blue", + "CdI": "tab:orange", + "HgI": "tab:green", + "KrI": "tab:red", + "NeI": "tab:purple", + "XeI": "tab:brown", + "O2": "tab:pink", + "OH": "tab:gray", + "OI": "tab:olive", + "NaI": "tab:cyan", } spectrograph_plot_markers = {1: "s", 2: "o", 3: "X", 4: "P"}