diff --git a/amstrax/__init__.py b/amstrax/__init__.py index 5e914074..37448992 100644 --- a/amstrax/__init__.py +++ b/amstrax/__init__.py @@ -1,11 +1,11 @@ -__version__ = '1.2.0' - -from . import plugins -from .plugins import * +__version__ = "1.2.1" from .common import * from .rundb import * +from . import plugins +from .plugins import * + from . import contexts from . import hitfinder_thresholds @@ -13,3 +13,8 @@ from . import auto_processing +from .mini_analysis import * + +from .matplotlib_utils import * + +from . import analyses diff --git a/amstrax/analyses/__init__.py b/amstrax/analyses/__init__.py new file mode 100644 index 00000000..8931f05f --- /dev/null +++ b/amstrax/analyses/__init__.py @@ -0,0 +1,4 @@ +from . import plotting +from . import waveform_plot +from . import records_matrix +from . import quick_checks diff --git a/amstrax/analyses/plotting.py b/amstrax/analyses/plotting.py new file mode 100644 index 00000000..0a122882 --- /dev/null +++ b/amstrax/analyses/plotting.py @@ -0,0 +1,457 @@ +import matplotlib.pyplot as plt +import numpy as np +import strax +import amstrax # for the mini_analysis decorator +import matplotlib as mpl +import pandas as pd +import datetime +import matplotlib.patches as mpatches +import collections +from matplotlib.lines import Line2D +import warnings + +export, __all__ = strax.exporter() + +@amstrax.mini_analysis(requires=("raw_records", "records")) +def plot_records(context, run_id, raw_records, records, raw=True, logy=False, **kwargs): + """ + Plot records (or raw records), one panel per channel + + :param context: strax.Context provided by the mini-analysis wrapper + :param run_id: Run ID of the data + :param raw: If True plot raw_records, otherwise plot processed records + :param logy: If True use a logarithmic y-axis + :param kwargs: Additional keyword arguments + :return: None + """ + + if raw: + records = raw_records + + # Create subplots for each channel + n_channels = max(records["channel"]) + 1 + + fig, axes = plt.subplots( + n_channels, 1, figsize=(8, 1.5 * (n_channels + 1)), sharex=True + ) + # Plot each channel's record + for i, ax in enumerate(axes): + _records = records[records["channel"] == i] + # For every record, plot the data (y) and time*dt (x) + for record in _records: + # Define timestamps for every sample of the record + # Start from record time, then every sample is dt long + # and add a full length*dt for every record_i + time = np.linspace(0, record["length"] - 1, record["length"]) * record["dt"] + time += record["time"] - records[0]["time"] + if raw: + data = -record["data"][: record["length"]] + else: + data = record["data"][: record["length"]] + + + # Plot the data + ax.plot(time, data) + # if logy, set the yscale to log + if logy: + ax.set_yscale("log") + + ax.set_ylabel("ADC Counts") + + # add text on the upper left corner of the plot + ax.text( + 0.05, + 0.95, + f"Channel {i}", + horizontalalignment="left", + verticalalignment="top", + transform=ax.transAxes, + ) + # ax.set_title(f"Channel {i}") + + # remove space between subplots in the figure + fig.subplots_adjust(wspace=0, hspace=0.1) + + axes[-1].set_ylabel("ADC Counts") + + # Set the x-axis label for the last subplot + axes[-1].set_xlabel("Time since first record [ns]") + + # Set the title + # Put the title above the subplots, not inside the frame + plt.subplots_adjust(top=0.95) + plt.suptitle(f"{'raw_records' if raw else 'records'} {records[0]['time']} from Run ID: {run_id}") + + # Show the plot + plt.show() + + +# Export the function +export(plot_records) + + +@amstrax.mini_analysis(requires=("peaks", "raw_records")) +def plot_peak_records(context, run_id, 
raw_records, peaks, **kwargs): + """ + Plot raw records of a peak and the total waveform. + + :param context: strax.Context provided by the mini-analysis wrapper + :param run_id: Run ID of the data + :param peaks: Peaks for which to plot raw records + :param kwargs: Additional keyword arguments + :return: None + """ + + # Get the raw records for the channels in the peak + # using strax.touching_windows + + records = raw_records + + if len(peaks) != 1: + raise ValueError( + "The time range you specified contains more or" + " less than a single event. The event display " + " only works with individual events for now." + ) + + peak = peaks[0] + + # Create subplots for each channel + n_channels = max(records["channel"]) + 1 + + fig, axes = plt.subplots( + n_channels + 1, 1, figsize=(8, 1.5 * (n_channels + 1)), sharex=True + ) + + # Plot each channel's raw record + for i, ax in enumerate(axes[:-1]): + _records = records[records["channel"] == i] + # For every record, plot the data (y) and time*dt (x) + for record in _records: + # Define timestamps for every sample of the record + # Start from record time, then every sample is dt long + # and add a full length*dt for every record_i + + time = np.linspace(0, record["length"] - 1, record["length"]) * record["dt"] + time += record["time"] - peak["time"] + data = -record["data"][: record["length"]] + + # Plot the data + ax.plot(time, data) + + ax.grid(alpha=0.4) + + ax.set_ylabel("ADC Counts") + + # add text on the upper left corner of the plot + ax.text( + 0.05, + 0.95, + f"Channel {i}", + horizontalalignment="left", + verticalalignment="top", + transform=ax.transAxes, + ) + # ax.set_title(f"Channel {i}") + + # Plot the total waveform (sum of all channels) + + # Make peak times, give the data and the dt + time = np.linspace(0, peak["length"] - 1, peak["length"]) * peak["dt"] + data = peak["data"][: peak["length"]] + + # remove space between subplots in the figure + fig.subplots_adjust(wspace=0, hspace=0.1) + + axes[-1].plot( + time, + data, + label="Total Waveform", + color="black", + ) + + axes[-1].set_ylabel("ADC Counts") + axes[-1].legend() + + axes[-1].grid(alpha=0.4) + + # Set the x-axis label for the last subplot + axes[-1].set_xlabel("Time since start of peak [ns]") + + # Set the title + # Put the title above the subplots, not inside the frame + plt.subplots_adjust(top=0.95) + plt.suptitle(f"Peak time: {peak['time']}, Run ID: {run_id}") + + # Show the plot + plt.show() + + +# Export the function +export(plot_peak_records) + + +# Let's make a function to plot the area per channel +# We need two panels, one next to each other +# Left panel with 4 quadrants, for channels 1 2 3 4 +# Right panel with 1 quadrant, for channel 0 +@amstrax.mini_analysis(requires=("peaks",)) +def plot_area_per_channel(context, run_id, peaks, **kwargs): + fig, axes = plt.subplots(1, 2, figsize=(8, 4)) + + top_quadrant_length = 1 + bottom_quadrant_length = 2 + + axes = axes.flatten() + + if len(peaks) != 1: + raise ValueError( + "The time range you specified contains more or" + " less than a single event. The event display " + " only works with individual events for now." 
+ ) + + peak = peaks[0] + + # Get the peaks + area_per_channel = peak["area_per_channel"] + + # Plot the area per channel + # Four quadrants for channels 1 2 3 4 + # imshow with 4 quadrants + + axes[0].imshow( + area_per_channel[1:5].reshape(2, 2), + cmap="viridis", + extent=[ + -top_quadrant_length, + top_quadrant_length, + -top_quadrant_length, + top_quadrant_length, + ], + origin="upper", + ) + axes[1].imshow( + area_per_channel[0].reshape(1, 1), + cmap="viridis", + extent=[ + -bottom_quadrant_length, + bottom_quadrant_length, + -bottom_quadrant_length, + bottom_quadrant_length, + ], + ) + + # Write the number of the channel in every element of imshow of axes[0] + # inside a little frame such that it is more readable + + for i in range(2): + for j in range(2): + ch = i * 2 + j + 1 + x = j - top_quadrant_length / 2 + # flip the y coordinate so the label lands on the row drawn + # by imshow with origin="upper" (row 0 is at the top) + y = top_quadrant_length / 2 - i + + t = axes[0].text( + x, + y, + # format nicely in scientific notation with x10^ + f"Ch {ch} \n {area_per_channel[ch]:.1e} PE", + horizontalalignment="center", + verticalalignment="center", + color="black", + ) + + t.set_bbox(dict(facecolor="white", alpha=0.5, edgecolor="black")) + + # do the same in channel 0 + t = axes[1].text( + 0, + 0, + f"Ch 0 \n {area_per_channel[0]:.1e} PE", + horizontalalignment="center", + verticalalignment="center", + color="black", + ) + + t.set_bbox(dict(facecolor="white", alpha=0.5, edgecolor="black")) + + # Set the title + # Put the title above the subplots, not inside the frame + plt.subplots_adjust(top=0.95) + plt.suptitle(f"Peak time: {peak['time']}, Run ID: {run_id}") + + # Add a common colorbar for both axes + # It needs to consider the values in both axes + # Create an axis for the colorbar + cax = fig.add_axes([0.95, 0.15, 0.02, 0.7]) + # Create the colorbar + cb = plt.colorbar( + axes[0].images[0], + cax=cax, + ) + + # Set the label of the colorbar + cb.set_label("Area per channel [PE]") + + +export(plot_area_per_channel) + + +@amstrax.mini_analysis(requires=("records_led",)) +def plot_led_records(context, run_id, records_led, n_records=100, **kwargs): + db = amstrax.get_mongo_collection() + + rd = db.find_one({"number": int(run_id)}) + + if not rd: + raise ValueError(f"Run {run_id} not found in the database.") + + print(rd["mode"]) + + if "ext_trig" not in rd["mode"]: + raise ValueError( + "This run doesn't look like an external trigger run, so better to avoid this plot." 
+ ) + + st = context + + # Get the records + records_led = records_led[0:n_records] + + records_led_config = st.get_single_plugin(run_id, "records_led").config + led_calibration_config = st.get_single_plugin(run_id, "led_calibration").config + + # Create a figure and axis + fig, ax = plt.subplots(figsize=(18, 6)) + colors = [k for i, k in mpl.colors.TABLEAU_COLORS.items()] + + # Plot the data (n_records was already applied above) + for r in records_led: + ax.plot(r["data"][0 : r["length"]], alpha=0.5, c=colors[int(r["channel"])]) + + # Create patches for each color + lines = [ + Line2D([0], [0], color=color, lw=2, label=f"Channel {i}") + for i, color in enumerate(colors[:5]) + ] + + # Add the legend + legend = ax.legend(handles=lines, frameon=True, edgecolor="black") + + # Shade regions between vertical lines and add labels + def shade_and_label(x1, x2, color, label): + ax.axvspan(x1, x2, color=color, alpha=0.2) + ax.text( + (x1 + x2) / 2, + ax.get_ylim()[1] * 0.95, + label, + ha="center", + va="top", + color=color, + ) + + shade_and_label( + records_led_config["baseline_window"][0], + records_led_config["baseline_window"][1], + "b", + "Baseline", + ) + shade_and_label( + led_calibration_config["led_window"][0], + led_calibration_config["led_window"][1], + "r", + "LED Signal", + ) + shade_and_label( + led_calibration_config["noise_window"][0], + led_calibration_config["noise_window"][1], + "g", + "Noise", + ) + + # Add vertical lines + ax.axvline(records_led_config["record_length"], linestyle="--", alpha=0.5, c="k") + + ax.axvline( + records_led_config["baseline_window"][0], linestyle="--", alpha=0.5, c="b" + ) + ax.axvline( + records_led_config["baseline_window"][1], linestyle="--", alpha=0.5, c="b" + ) + + ax.axvline( + led_calibration_config["led_window"][0], linestyle="--", alpha=0.5, c="r" + ) + ax.axvline( + led_calibration_config["led_window"][1], linestyle="--", alpha=0.5, c="r" + ) + + ax.axvline( + led_calibration_config["noise_window"][0], linestyle="--", alpha=0.5, c="g" + ) + ax.axvline( + led_calibration_config["noise_window"][1], linestyle="--", alpha=0.5, c="g" + ) + + # Add grid, labels, and title + ax.grid(alpha=0.4, which="both") + ax.set_xlabel("Samples") + ax.set_title(f"LED records with baseline, LED and noise windows, run {run_id}") + + # Display the plot + plt.show() + + +export(plot_led_records) + + +@amstrax.mini_analysis(requires=("led_calibration",)) +def plot_led_areas(context, run_id, led_calibration, **kwargs): + # Get led calibration + + db = amstrax.get_mongo_collection() + + rd = db.find_one({"number": int(run_id)}) + + if not rd: + raise ValueError(f"Run {run_id} not found in the database.") + + print(rd["mode"]) + + if "ext_trig" not in rd["mode"]: + raise ValueError( + "This run doesn't look like an external trigger run, so better to avoid this plot." 
+ ) + + fig, ax = plt.subplots(figsize=(15, 5)) + + colors = [k for i, k in mpl.colors.TABLEAU_COLORS.items()] + + # Loop through each channel and plot its histogram + for i in range(5): + channel_data = led_calibration[led_calibration["channel"] == i]["area"] + ax.hist( + channel_data, + bins=250, + histtype="step", + label=f"Channel {i}", + color=colors[i], + ) + + ax.set_yscale("log") + ax.legend(loc="upper right") + ax.set_title("Histogram for Every Channel") + ax.set_xlabel("Area") + ax.set_ylabel("Frequency (log scale)") + ax.grid(alpha=0.4, which="both") + + plt.show() + + +export(plot_led_areas) diff --git a/amstrax/analyses/quick_checks.py b/amstrax/analyses/quick_checks.py new file mode 100644 index 00000000..0b4ea616 --- /dev/null +++ b/amstrax/analyses/quick_checks.py @@ -0,0 +1,277 @@ +# from https://github.com/XENONnT/amstrax/blob/master/amstrax/analyses/quick_checks.py + +import matplotlib.pyplot as plt +import numpy as np +import strax +import amstrax +from multihist import Hist1d, Histdd +from matplotlib.colors import LogNorm + + +@amstrax.mini_analysis(requires=('peak_basics',)) +def plot_peaks_aft_histogram( + context, run_id, peaks, + pe_bins=np.logspace(0, 7, 120), + rt_bins=np.geomspace(2, 1e5, 120), + extra_labels=tuple(), + rate_range=(1e-4, 1), + aft_range=(0, .85), + figsize=(14, 5)): + """Plot side-by-side (area, width) histograms of the peak rate + and mean area fraction top. + + :param pe_bins: Array of bin edges for the peak area dimension [PE] + :param rt_bins: array of bin edges for the rise time dimension [ns] + :param extra_labels: List of (area, risetime, text, color) extra labels + to put on the plot + :param rate_range: Range of rates to show [peaks/(bin*s)] + :param aft_range: Range of mean S1 area fraction top / bin to show + :param figsize: Figure size to use + """ + livetime_sec = get_livetime_sec(context, run_id, peaks) + + mh = Histdd(peaks, + dimensions=( + ('area', pe_bins), + ('range_50p_area', rt_bins), + ('area_fraction_top', np.linspace(0, 1, 100)), + + )) + + f, axes = plt.subplots(1, 2, figsize=figsize) + + def std_axes(): + plt.gca().set_facecolor('k') + plt.yscale('log') + plt.xscale('log') + plt.xlabel("Area [PE]") + plt.ylabel("Range 50% area [ns]") + labels = [ + (12, 8, "AP?", 'white'), + (3, 150, "1PE\npileup", 'gray'), + + (30, 200, "1e", 'gray'), + (100, 1000, "n-e", 'w'), + (2000, 2e4, "Train", 'gray'), + + (1200, 50, "S1", 'w'), + (45e3, 60, "αS1", 'w'), + + (2e5, 800, "S2", 'w'), + ] + list(extra_labels) + + for x, w, text, color in labels: + plt.text(x, w, text, color=color, + verticalalignment='center', + horizontalalignment='center') + + plt.sca(axes[0]) + (mh / livetime_sec).sum(axis=2).plot( + norm=LogNorm(vmin=rate_range[0], vmax=rate_range[1]), + colorbar_kwargs=dict(extend='both'), + cblabel='Peaks / (bin * s)') + std_axes() + + plt.sca(axes[1]) + mh.average(axis=2).plot( + vmin=aft_range[0], vmax=aft_range[1], + colorbar_kwargs=dict(extend='max'), + cmap=plt.cm.jet, + cblabel='Mean area fraction top') + + std_axes() + plt.tight_layout() + +@amstrax.mini_analysis(requires=['event_info']) +def event_scatter(context, run_id, events, + show_single=True, + s=10, + color_range=(None, None), + color_dim='s1_area_fraction_top', + figsize=(7, 5)): + """Plot a (cS1, cS2) event scatter plot + + :param show_single: Show events with only S1s or only S2s just beside + the axes. + :param s: Scatter size + :param color_dim: Dimension to use for the color. Must be in event_info. + :param color_range: Minimum and maximum color value to show. 
+ :param figsize: (w, h) figure size to use, or leave None to not make a + new matplotlib figure. + """ + if figsize is not None: + plt.figure(figsize=figsize) + if color_dim == 's1_area_fraction_top' and color_range == (None, None): + color_range = (0, 0.3) + + plt.scatter(np.nan_to_num(events['cs1']).clip(.9, None), + np.nan_to_num(events['cs2']).clip(.9, None), + clip_on=not show_single, + c=events[color_dim], + vmin=color_range[0], vmax=color_range[1], + s=s, + cmap=plt.cm.jet, + marker='.', edgecolors='none') + + plt.xlabel('cS1 [PE]') + plt.xscale('log') + plt.xlim(1, None) + + plt.ylabel('cS2 [PE]') + plt.yscale('log') + plt.ylim(1, None) + + p = context.get_single_plugin(run_id, 'energy_estimates') + ax = plt.gca() + el_lim = p.cs1_to_e(np.asarray(ax.get_xlim())) + ec_lim = p.cs2_to_e(np.asarray(ax.get_ylim())) + + ax2 = ax.twiny() + ax2.set_xlim(*el_lim) + ax2.set_xscale('log') + ax2.set_xlabel("E_light [keVee]") + + ax3 = ax2.twinx() + ax3.set_ylim(*ec_lim) + ax3.set_yscale('log') + ax3.set_ylabel("E_charge [keVee]") + + plt.sca(ax3) + plt.plot(el_lim, el_lim, c='k', alpha=0.2) + x = np.geomspace(*el_lim, num=1000) + e_label = 1.2e-3 + for e_const, label in [ + (0.1, ''), (1, '1\nkeV'), (10, '10\nkeV'), + (100, '100\nkeV'), (1e3, '1\nMeV'), (1e4, '')]: + plt.plot(x, e_const - x, c='k', alpha=0.2) + plt.text(e_const - e_label, e_label, label, + bbox=dict(facecolor='white', alpha=0.5, edgecolor='none'), + horizontalalignment='center', verticalalignment='center', + color='k', alpha=0.5) + + plt.sca(ax) + if color_range[0] is None: + extend = 'neither' if color_range[1] is None else 'max' + else: + extend = 'min' if color_range[1] is None else 'both' + if color_dim == 's1_area_fraction_top': + plt.colorbar(label="S1 area fraction top", + extend=extend, + ax=[ax, ax3]) + else: + plt.colorbar(label=color_dim, + extend=extend, + ax=[ax, ax3]) + + +@amstrax.mini_analysis(requires=('event_info',)) +def plot_energy_spectrum( + events, + color='b', label=None, + unit=None, exposure_kg_sec=None, + error_alpha=0.5, errors='fc', + n_bins=100, min_energy=1, max_energy=100, geomspace=True): + """Plot an energy spectrum histogram, with 1 sigma + Poisson confidence intervals around it. + + :param exposure_kg_sec: Exposure in kg * sec + :param unit: Unit to plot spectrum in. Can be either: + - events (events per bin) + - kg_day_kev (events per kg day keV) + - tonne_day_kev (events per tonne day keV) + - tonne_year_kev (events per tonne year keV) + Defaults to kg_day_kev if exposure_kg_sec is provided, + otherwise events. + + :param min_energy: Minimum energy of the histogram + :param max_energy: Maximum energy of the histogram + :param geomspace: If True, will use a logarithmic energy binning. + Otherwise will use a linear scale. + :param n_bins: Number of energy bins to use + + :param color: Color to plot in + :param label: Label for the line + :param error_alpha: Alpha value for the statistical error band + :param errors: Type of errors to draw, passed to 'errors' + argument of Hist1d.plot. 
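+
+    Example (a sketch: assumes a configured amstrax context ``st``, as in
+    the records_matrix docstring; the exposure value is purely illustrative):
+        st.plot_energy_spectrum(run_id, unit='kg_day_kev',
+                                exposure_kg_sec=7200)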
+ """ + if unit is None: + if exposure_kg_sec is not None: + unit = 'kg_day_kev' + else: + unit = 'events' + + h = Hist1d(events['e_ces'], + bins=(np.geomspace if geomspace else np.linspace)( + min_energy, max_energy, n_bins)) + + if unit == 'events': + scale, ylabel = 1, 'Events per bin' + else: + if exposure_kg_sec is None: + raise ValueError('you did not specify exposure_kg_sec') + exposure_kg_day = exposure_kg_sec / (3600 * 24) + if unit == 'kg_day_kev': + scale = exposure_kg_day + ylabel = 'Events / (kg day keV)' + elif unit == 'tonne_day_kev': + scale = exposure_kg_day / 1000 + ylabel = 'Events / (tonne day keV)' + elif unit == 'tonne_year_kev': + scale = exposure_kg_day / 1000 + ylabel = 'Events / (tonne year keV)' + else: + raise ValueError(f"Invalid unit {unit}") + scale *= h.bin_volumes() + + h.plot(errors=errors, + error_style='band', + color=color, + label=label, + linewidth=1, + scale_histogram_by=1 / scale, + error_alpha=error_alpha) + plt.yscale('log') + if geomspace: + amstrax.log_x(min_energy, max_energy, scalar_ticks=True) + else: + plt.xlim(min_energy, max_energy) + plt.ylabel(ylabel) + plt.xlabel("Energy [keV_ee], CES") + + +@amstrax.mini_analysis(requires=('peak_basics',)) +def plot_peak_classification(peaks, s=1): + """Make an (area, rise_time) scatter plot of peaks + :param s: Size of dot for each peak + """ + for cl, color in enumerate('kbg'): + d = peaks[peaks['type'] == cl] + plt.scatter(d['area'], d['rise_time'], c=color, + s=s, marker='.', edgecolors='none', + label={0: 'Unknown', 1: 'S1', 2: 'S2'}[cl]) + plt.legend(loc='lower right', markerscale=10) + + plt.xscale('log') + plt.yscale('log') + plt.xlim(1, 2e6) + plt.ylim(3, 1e4) + plt.xlabel("Area [PE]") + plt.ylabel("Rise time [ns]") + + +def get_livetime_sec(context, run_id, things=None): + """Get the livetime of a run in seconds. If it is not in the run metadata, + estimate it from the data-level metadata of the data things. + """ + try: + md = context.run_metadata(run_id, + projection=('start', 'end', 'livetime')) + except strax.RunMetadataNotAvailable: + if things is None: + raise + return (strax.endtime(things[-1]) - things[0]['time']) / 1e9 + else: + if 'livetime' in md: + return md['livetime'] + else: + return (md['end'] - md['start']).total_seconds() diff --git a/amstrax/analyses/records_matrix.py b/amstrax/analyses/records_matrix.py new file mode 100644 index 00000000..b6b43000 --- /dev/null +++ b/amstrax/analyses/records_matrix.py @@ -0,0 +1,159 @@ +# From https://github.com/XENONnT/amstrax/blob/master/amstrax/analyses/records_matrix.py + +import warnings + +import numba +import numpy as np +import strax +import amstrax + +DEFAULT_MAX_SAMPLES = 20_000 + + +@amstrax.mini_analysis(requires=('records',), + warn_beyond_sec=10, + default_time_selection='touching') +def records_matrix(records, time_range, seconds_range, config, to_pe, + max_samples=DEFAULT_MAX_SAMPLES, + ignore_max_sample_warning=False): + """Return (wv_matrix, times, pms) + - wv_matrix: (n_samples, n_pmt) array with per-PMT waveform intensity in PE/ns + - times: time labels in seconds (corr. to rows) + - pmts: PMT numbers (corr. to columns) + Both times and pmts have one extra element. + + :param max_samples: Maximum number of time samples. If window and dt + conspire to exceed this, waveforms will be downsampled. + :param ignore_max_sample_warning: If True, suppress warning when this happens. 
+ + Example: + wvm, ts, ys = st.records_matrix(run_id, seconds_range=(1., 1.00001)) + plt.pcolormesh(ts, ys, wvm.T, + norm=matplotlib.colors.LogNorm()) + plt.colorbar(label='Intensity [PE / ns]') + """ + if len(records): + dt = records[0]['dt'] + samples_per_record = len(records[0]['data']) + else: + # Defaults here do not matter, nothing will be plotted anyway + dt = 10 + samples_per_record = 110 + record_duration = samples_per_record * dt + + window = time_range[1] - time_range[0] + if window / dt > max_samples: + with np.errstate(divide='ignore', invalid='ignore'): + # Downsample. New dt must be + # a) multiple of old dt + dts = np.arange(0, record_duration + dt, dt).astype(np.int64) + # b) divisor of record duration + dts = dts[record_duration / dts % 1 == 0] + # c) total samples < max_samples + dts = dts[window / dts < max_samples] + if len(dts): + # Pick lowest dt that satisfies criteria + dt = dts.min() + else: + # Records will be downsampled to single points + dt = max(record_duration, window // max_samples) + if not ignore_max_sample_warning: + warnings.warn(f"Matrix would exceed max_samples {max_samples}, " + f"downsampling to dt = {dt} ns.") + + wvm = _records_to_matrix( + records, + t0=time_range[0], + n_channels=config['n_tpc_pmts'], + dt=dt, + window=window) + + try: + wvm = wvm.astype(np.float32) * to_pe.reshape(1, -1) / dt + except (AttributeError, ValueError): + # to_pe is a scalar (as currently in amstrax, where to_pe = 1), + # not a per-channel array + wvm = wvm.astype(np.float32) * 1 / dt + + + # Note + 1, so data for sample 0 will range from 0-1 in plot + ts = (np.arange(wvm.shape[0] + 1) * dt / int(1e9) + seconds_range[0]) + ys = np.arange(wvm.shape[1] + 1) + + return wvm, ts, ys + + +@amstrax.mini_analysis(requires=('raw_records',), + warn_beyond_sec=3e-3, + default_time_selection='touching') +def raw_records_matrix(context, run_id, raw_records, time_range, + ignore_max_sample_warning=False, + max_samples=DEFAULT_MAX_SAMPLES, + **kwargs): + # Convert raw to records. We may not be able to baseline correctly + # at the start of the range due to missing zeroth fragments + records = strax.raw_to_records(raw_records) + strax.baseline(records, allow_sloppy_chunking=True) + strax.zero_out_of_bounds(records) + + return context.records_matrix(run_id=run_id, + records=records, + time_range=time_range, + max_samples=max_samples, + ignore_max_sample_warning=ignore_max_sample_warning, + **kwargs) + + +def _records_to_matrix(records, t0, window, n_channels, dt=10): + if np.any(records['amplitude_bit_shift'] > 0): + warnings.warn('Ignoring amplitude bitshift!') + return _records_to_matrix_inner(records, t0, window, n_channels, dt) + + +@numba.njit +def _records_to_matrix_inner(records, t0, window, n_channels, dt=10): + n_samples = (window // dt) + 1 + # Use 32-bit integers, so downsampling saturated samples doesn't + # cause wraparounds + y = np.zeros((n_samples, n_channels), + dtype=np.int32) + + if not len(records): + return y + samples_per_record = len(records[0]['data']) + + for r in records: + if r['channel'] >= n_channels: + continue + + if dt >= samples_per_record * r['dt']: + # Downsample to single sample -> store area + idx = (r['time'] - t0) // dt + if idx >= len(y): + print(len(y), idx) + raise IndexError('Despite n_samples = window // dt + 1, our ' + 'idx is too high?!') + y[idx, r['channel']] += r['area'] + continue + + # Assume out-of-bounds data has been zeroed, so we do not + # need to do r['data'][:r['length']] here. + # This simplifies downsampling. 
+ w = r['data'].astype(np.int32) + + if dt > r['dt']: + # Downsample + duration = samples_per_record * r['dt'] + if duration % dt != 0: + raise ValueError("Cannot downsample fractionally") + # .astype here keeps numba happy ... ?? + w = w.reshape(duration // dt, -1).sum(axis=1).astype(np.int32) + + elif dt < r['dt']: + raise ValueError("Upsampling not yet implemented") + + (r_start, r_end), (y_start, y_end) = strax.overlap_indices( + r['time'] // dt, len(w), + t0 // dt, n_samples) + # += is paranoid, data in individual channels should not overlap + # but... https://github.com/AxFoundation/strax/issues/119 + y[y_start:y_end, r['channel']] += w[r_start:r_end] + + return y diff --git a/amstrax/analyses/waveform_plot.py b/amstrax/analyses/waveform_plot.py new file mode 100644 index 00000000..8821b62a --- /dev/null +++ b/amstrax/analyses/waveform_plot.py @@ -0,0 +1,314 @@ +# From https://github.com/XENONnT/amstrax/blob/master/amstrax/analyses/waveform_plot.py + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import strax +import amstrax +from mpl_toolkits.axes_grid1 import inset_locator + + +DEFAULT_MAX_SAMPLES = 20_000 + +export, __all__ = strax.exporter() +__all__ += ["plot_wf"] + + +@amstrax.mini_analysis() +def plot_waveform( + context, + deep=False, + show_largest=100, + figsize=None, + max_samples=DEFAULT_MAX_SAMPLES, + ignore_max_sample_warning=True, + cbar_loc="lower right", + lower_panel_height=2, + **kwargs, +): + """Plot the sum waveform and optionally per-PMT waveforms + + :param deep: If True, show per-PMT waveform matrix under sum waveform. + If 'raw', use raw_records instead of records to do so. + :param show_largest: Show only the largest show_largest peaks. + :param figsize: Matplotlib figure size for the plot + + Additional options for deep = True or raw: + :param cbar_loc: location of the intensity color bar. Set to None + to omit it altogether. + :param lower_panel_height: Height of the lower panel in terms of + the height of the upper panel. 
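+
+    Example (a sketch: assumes a configured amstrax context ``st``; the
+    time window is purely illustrative):
+        st.plot_waveform(run_id, seconds_range=(0., 0.001), deep=True)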
+ """ + if figsize is None: + figsize = (10, 6 if deep else 4) + + if not deep: + context.plot_peaks(**kwargs, show_largest=show_largest, figsize=figsize) + + else: + f, axes = plt.subplots( + 2, + 1, + constrained_layout=True, + figsize=figsize, + gridspec_kw={"height_ratios": [1, lower_panel_height]}, + ) + + plt.sca(axes[0]) + context.plot_peaks( + **kwargs, show_largest=show_largest, single_figure=False, xaxis=False + ) + + plt.sca(axes[1]) + context.plot_records_matrix( + **kwargs, + cbar_loc=cbar_loc, + max_samples=max_samples, + ignore_max_sample_warning=ignore_max_sample_warning, + raw=deep == "raw", + single_figure=False, + ) + + plt.subplots_adjust(hspace=0) + + +@amstrax.mini_analysis( + requires=("peaks", "peak_basics"), + default_time_selection="touching", + warn_beyond_sec=60, +) +def plot_peaks( + peaks, + seconds_range, + t_reference, + show_largest=100, + single_figure=True, + figsize=(10, 4), + xaxis=True, +): + if single_figure: + plt.figure(figsize=figsize) + plt.axhline(0, c="k", alpha=0.2) + + peaks = peaks[np.argsort(-peaks["area"])[:show_largest]] + peaks = strax.sort_by_time(peaks) + + for p in peaks: + plot_peak(p, t0=t_reference, color={0: "gray", 1: "b", 2: "g"}[p["type"]]) + + if xaxis == "since_start": + seconds_range_xaxis(seconds_range, t0=seconds_range[0]) + elif xaxis: + seconds_range_xaxis(seconds_range) + plt.xlim(*seconds_range) + else: + plt.xticks([]) + plt.xlim(*seconds_range) + plt.ylabel("Intensity [PE/ns]") + if single_figure: + plt.tight_layout() + + +@amstrax.mini_analysis( + requires=("peaks", "peak_basics"), + default_time_selection="touching", + warn_beyond_sec=60, +) +def plot_hit_pattern( + peaks, + seconds_range, + t_reference, + axes=None, + vmin=None, + log_scale=False, + label=None, + single_figure=False, + xenon1t=False, + figsize=(10, 4), +): + if single_figure: + plt.figure(figsize=figsize) + if len(peaks) > 1: + print(f"warning showing total area of {len(peaks)} peaks") + amstrax.plot_pmts( + np.sum(peaks["area_per_channel"], axis=0), + axes=axes, + vmin=vmin, + log_scale=log_scale, + label=label, + xenon1t=xenon1t, + ) + + +@amstrax.mini_analysis() +def plot_records_matrix( + context, + run_id, + seconds_range, + cbar_loc="upper right", + raw=False, + single_figure=True, + figsize=(10, 4), + group_by=None, + max_samples=DEFAULT_MAX_SAMPLES, + ignore_max_sample_warning=False, + vmin=None, + vmax=None, + **kwargs, +): + if seconds_range is None: + raise ValueError( + "You must pass a time selection (e.g. seconds_range) " + "to plot_records_matrix." + ) + + if single_figure: + plt.figure(figsize=figsize, constrained_layout=True) + + f = context.raw_records_matrix if raw else context.records_matrix + + wvm, ts, ys = f( + run_id, + max_samples=max_samples, + ignore_max_sample_warning=ignore_max_sample_warning, + **kwargs, + ) + + + plt.ylabel("Channel number") + + # extract min and max from kwargs or set defaults + if vmin is None: + vmin = min(0.1 * wvm.max(), 1e-2) + if vmax is None: + vmax = wvm.max() + + plt.pcolormesh( + ts, + ys, + wvm.T, + norm=matplotlib.colors.LogNorm( + vmin=vmin, + vmax=vmax, + ), + cmap=plt.cm.inferno, + ) + # plt.xlim(*seconds_range) + + ax = plt.gca() + if group_by is not None: + # Do some magic to convert all the labels to an integer that + # allows for remapping of the y labels to whatever is provided + # in the "ylabs", otherwise matplotlib shows nchannels different + # labels in the case of strings. 
+ # Make a dict that converts the label to an int + int_labels = {h: i for i, h in enumerate(set(ylabs))} + mask = np.ones(len(ylabs), dtype=np.bool_) + # If the label (int) is different wrt. its neighbour, show it + mask[1:] = np.abs(np.diff([int_labels[y] for y in ylabs])) > 0 + # Only label the selection + ax.set_yticks(np.arange(len(ylabs))[mask]) + ax.set_yticklabels(ylabs[mask]) + plt.xlabel("Time [s]") + + if cbar_loc is not None: + # Create a white box to place the color bar in + # See https://stackoverflow.com/questions/18211967 + bbox = inset_locator.inset_axes(ax, width="20%", height="22%", loc=cbar_loc) + _ = [bbox.spines[k].set_visible(False) for k in bbox.spines] + bbox.patch.set_facecolor((1, 1, 1, 0.9)) + bbox.set_xticks([]) + bbox.set_yticks([]) + + # Create the actual color bar + cax = inset_locator.inset_axes( + bbox, width="90%", height="20%", loc="upper center" + ) + plt.colorbar(cax=cax, label="Intensity [PE/ns]", orientation="horizontal") + cax.xaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter("%g")) + + plt.sca(ax) + + +def seconds_range_xaxis(seconds_range, t0=None): + """Make a pretty time axis given seconds_range""" + plt.xlim(*seconds_range) + ax = plt.gca() + ax.ticklabel_format(useOffset=False) + xticks = plt.xticks()[0] + if not len(xticks): + return + + # Format the labels + # I am not very proud of this code... + def chop(x): + return np.floor(x).astype(np.int64) + + if t0 is None: + xticks_ns = np.round(xticks * int(1e9)).astype(np.int64) + else: + xticks_ns = np.round((xticks - xticks[0]) * int(1e9)).astype(np.int64) + sec = chop(xticks_ns // int(1e9)) + ms = chop((xticks_ns % int(1e9)) // int(1e6)) + us = chop((xticks_ns % int(1e6)) // int(1e3)) + samples = chop((xticks_ns % int(1e3)) // 10) + + labels = [str(sec[i]) for i in range(len(xticks))] + print_ns = np.any(samples != samples[0]) + print_us = print_ns | np.any(us != us[0]) + print_ms = print_us | np.any(ms != ms[0]) + if print_ms and t0 is None: + labels = [l + f".{ms[i]:03}" for i, l in enumerate(labels)] + if print_us: + labels = [ + l + r" $\bf{" + f"{us[i]:03}" + "}$" for i, l in enumerate(labels) + ] + if print_ns: + labels = [l + f" {samples[i]:02}0" for i, l in enumerate(labels)] + plt.xticks(ticks=xticks, labels=labels, rotation=90) + else: + labels = list(chop((xticks_ns // 10) * 10)) + labels[-1] = "" + plt.xticks(ticks=xticks, labels=labels, rotation=0) + if t0 is None: + plt.xlabel("Time since run start [sec]") + else: + plt.xlabel("Time [ns]") + + +def plot_peak(p, t0=None, center_time=True, **kwargs): + x, y = time_and_samples(p, t0=t0) + kwargs.setdefault("linewidth", 1) + + # Plot waveform + plt.plot(x, y, drawstyle="steps-pre", **kwargs) + if "linewidth" in kwargs: + del kwargs["linewidth"] + kwargs["alpha"] = kwargs.get("alpha", 1) * 0.2 + plt.fill_between(x, 0, y, step="pre", linewidth=0, **kwargs) + + # Mark extent with thin black line + plt.plot([x[0], x[-1]], [y.max(), y.max()], c="k", alpha=0.3, linewidth=1) + + # Mark center time with thin black line + if center_time: + if t0 is None: + t0 = p["time"] + ct = (p["center_time"] - t0) / int(1e9) + plt.axvline(ct, c="k", alpha=0.4, linewidth=1, linestyle="--") + + +def time_and_samples(p, t0=None): + """Return (x, y) numpy arrays for plotting the waveform data in p + using 'steps-pre'. + Where x is the time since t0 in seconds (or another time_scale), + and y is intensity in PE / ns. 
+ :param p: Peak or other similar strax data type + :param t0: Zero of time in ns since unix epoch + """ + n = p["length"] + if t0 is None: + t0 = p["time"] + x = ((p["time"] - t0) + np.arange(n + 1) * p["dt"]) / int(1e9) + y = p["data"][:n] / p["dt"] + return x, np.concatenate([[y[0]], y]) diff --git a/amstrax/auto_processing/amstraxer_easy.py b/amstrax/auto_processing/amstraxer_easy.py index 76cc462b..e20c7089 100644 --- a/amstrax/auto_processing/amstraxer_easy.py +++ b/amstrax/auto_processing/amstraxer_easy.py @@ -23,7 +23,7 @@ def parse_args(): help="Name of context to use") parser.add_argument( '--target', - default=['raw_records'], + default=['raw_records',], nargs="*", help='Target final data type to produce') parser.add_argument( diff --git a/amstrax/auto_processing/auto_processing.py b/amstrax/auto_processing/auto_processing.py index 7d1627c6..3067220b 100644 --- a/amstrax/auto_processing/auto_processing.py +++ b/amstrax/auto_processing/auto_processing.py @@ -1,6 +1,6 @@ import argparse import time -from datetime import datetime +from datetime import datetime, timedelta import subprocess def parse_args(): @@ -10,7 +10,7 @@ def parse_args(): parser.add_argument( '--target', nargs="*", - default=['peak_basics'], + default=['raw_records',], help="Target final data type to produce.") parser.add_argument( '--output_folder', @@ -66,7 +66,10 @@ def parse_args(): output_folder = args.output_folder process_stomboot = args.process_stomboot detector = args.detector - target = args.target + + targets = args.target + targets = " ".join(targets) + runs_col = amstrax.get_mongo_collection(detector) print('Correctly connected, starting loop') @@ -79,7 +82,7 @@ def parse_args(): run_docs_to_do = list(runs_col.find({ 'processing_status':{'$ne': 'done'}, 'end':{"$ne":None}, - 'start':{'$gt': datetime(2023,1,25)}, + 'start':{'$gt': datetime.today() - timedelta(days=5)}, 'processing_failed':{'$not': {'$gt': 3}}, }).sort('start', -1)) @@ -96,15 +99,14 @@ def parse_args(): run_name = f'{int(run_doc["number"]):06}' if process_stomboot: - submit_stbc.submit_job(run_name, target=target, context=context, detector=detector,script='process_run') - runs_col.find_one_and_update({'number': run_name}, - {'$set': {'processing_status': 'submitted_job' }}) + # submit_stbc.submit_job(run_name, target=target, context=context, detector=detector,script='process_run') + # runs_col.find_one_and_update({'number': run_name}, + # {'$set': {'processing_status': 'submitted_job' }}) + pass else: #process locally - runs_col.find_one_and_update({'number': run_name}, - {'$set': {'processing_status': 'processing'}}) - target = " ".join(target) - subprocess.run(f"process_run {run_name} --target {target} --output_folder {output_folder}", shell=True) + print(f'Processing run {run_name}, target [{targets}]') + subprocess.run(f"process_run {run_name} --target {targets} --output_folder {output_folder}", shell=True) time.sleep(2) if max_jobs is not None and len(run_docs_to_do) > max_jobs: diff --git a/amstrax/common.py b/amstrax/common.py index 966762df..d167aeef 100644 --- a/amstrax/common.py +++ b/amstrax/common.py @@ -17,7 +17,21 @@ import strax export, __all__ = strax.exporter() -__all__ += ['amstrax_dir', 'to_pe'] +__all__ += ['amstrax_dir', + 'to_pe', + 'n_tpc_pmts', + 'n_xamsl_channel', + 'tpc_r', + 'tpc_z', + ] + +# Current values +n_tpc_pmts = 8 +n_xamsl_channel = 4 +to_pe = 1 + +tpc_r = 6 # TODO check this value +tpc_z = 10 # TODO check this value # Current values n_tpc_pmts = 8 diff --git a/amstrax/contexts.py 
b/amstrax/contexts.py index 4745cbd3..e45375f4 100644 --- a/amstrax/contexts.py +++ b/amstrax/contexts.py @@ -23,12 +23,16 @@ ax.Peaks, ax.PeakClassification, ax.PeakBasics, + # ax.RecordsLED, + # ax.LEDCalibration, + ax.PeakClassification, + ax.PeakProximity, + ax.PeakPositions, ax.Events, - ax.RecordsLED, - ax.LEDCalibration, - # ax.EventBasics, - # ax.EventPositions, # ax.CorrectedAreas, + # ax.EventPositions, + # ax.EventBasics, + # ax.EventInfo, # ax.EnergyEstimates, ], store_run_fields=( @@ -143,6 +147,20 @@ def context_for_daq_reader(st: strax.Context, UserWarning(f'You changed the context for {run_id}. Do not process any other run!') return st +def xams_led(**kwargs): + st = xams(**kwargs) + st.set_context_config( + {'check_available': ('raw_records', 'records_led', 'led_calibration')}) + # Return a new context with only raw_records and led_calibration registered + st = st.new_context( + replace=True, + config=st.config, + storage=st.storage, + **st.context_config) + st.register([ax.DAQReader, + ax.RecordsLED, + ax.LEDCalibration]) + return st def _check_raw_records_exists(st: strax.Context, run_id: str) -> bool: for plugin_name in st._plugin_class_registry.keys(): diff --git a/amstrax/legacy/SiPMpositions.py b/amstrax/legacy/SiPMpositions.py new file mode 100644 index 00000000..56c18caf --- /dev/null +++ b/amstrax/legacy/SiPMpositions.py @@ -0,0 +1,84 @@ +import numba +import numpy as np +import strax +export, __all__ = strax.exporter() + +from amstrax.SiPMdata import * +# move this to legacy once you have the new peak_positions.py in amstrax + +@export +class PeakPositions(strax.Plugin): + depends_on = ('peaks', 'peak_classification') + rechunk_on_save = False + __version__ = '0.0.34' # .33 for LNLIKE + dtype = [ + ('xr', np.float32, + 'Interaction x-position'), + ('yr', np.float32, + 'Interaction y-position'), + ('r', np.float32, + 'radial distance'), + ('time', np.int64, 'Start time of the peak (ns since unix epoch)'), + ('endtime', np.int64, 'End time of the peak (ns since unix epoch)') + ] + + def setup(self): + # z position of the in-plane SiPMs + z_plane = 10 + # radius of the cylinder for SiPMs at the side + r_cylinder = 22 + # radius of a SiPM - I assume circular SiPMs with a radius to make the area correspond to a 3x3mm2 square. 
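+        # (equal area: pi * r^2 = 9 mm^2  =>  r = 3 / sqrt(pi) ~= 1.6925 mm)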
+ r_sipm = 1.6925 + # build geometry + geo = GeoParameters(z_plane=z_plane, r_cylinder=r_cylinder, r_sipm=r_sipm) + + sipm = SiPM(type="plane", position=[0, -15, z_plane], qeff=0.25) + geo.add_sipm(sipm) + sipm = SiPM(type="plane", position=[-13, -7.5, z_plane], qeff=0.25) + geo.add_sipm(sipm) + # sipm = SiPM(type="plane", position=[0, 15, z_plane], qeff=0.25) + # geo.add_sipm(sipm) + sipm = SiPM(type="plane", position=[13, -7.5, z_plane], qeff=0.25) + geo.add_sipm(sipm) + sipm = SiPM(type="plane", position=[-4, 0, z_plane], qeff=0.25) + geo.add_sipm(sipm) + sipm = SiPM(type="plane", position=[4, 0, z_plane], qeff=0.25) + geo.add_sipm(sipm) + sipm = SiPM(type="plane", position=[-13, 7.5, z_plane], qeff=0.25) + geo.add_sipm(sipm) + sipm = SiPM(type="plane", position=[13, 7.5, z_plane], qeff=0.25) + geo.add_sipm(sipm) + + self.geo = geo + + def compute(self, peaks): + + result = np.empty(len(peaks), dtype=self.dtype) + # np.empty gives uninitialized memory: fill the required time fields + # for all peaks and default the positions to NaN (non-S2 peaks are + # skipped below) + result['time'] = peaks['time'] + result['endtime'] = strax.endtime(peaks) + result['xr'] = np.nan + result['yr'] = np.nan + + if not len(peaks): + return result + + for ix, p in enumerate(peaks): + + if p['type'] != 2: + continue + + # drop channel 2, which is not working (cf. the commented-out SiPM above) + k = np.delete(p['area_per_channel'], [2]) + for i, area in enumerate(k): + self.geo.sipms[i].set_number_of_hits(area) + + # if all 8 channels are working + # for i, area in enumerate(p['area_per_channel']): + # self.geo.sipms[i].set_number_of_hits(area) + + posrec = Reconstruction(self.geo) + pos = posrec.reconstruct_position('CHI2') + for key in ['xr', 'yr']: + result[key][ix] = pos[key] + + result['r'] = (result['xr'] ** 2 + result['yr'] ** 2) ** (1 / 2) + return result diff --git a/amstrax/legacy/event_positions.py b/amstrax/legacy/event_positions.py new file mode 100644 index 00000000..eadc2120 --- /dev/null +++ b/amstrax/legacy/event_positions.py @@ -0,0 +1,65 @@ +import numba +import numpy as np +import strax +from amstrax.SiPMdata import * # as in SiPMpositions.py: GeoParameters, SiPM, Reconstruction +export, __all__ = strax.exporter() + +@export +class EventPositions(strax.LoopPlugin): + depends_on = ('events', 'event_basics', 'peaks', 'peak_classification') + rechunk_on_save = False + dtype = [ + ('xr', np.float32, + 'Interaction x-position'), + ('yr', np.float32, + 'Interaction y-position'), + ('time', np.int64, 'Event start time in ns since the unix epoch'), + ('endtime', np.int64, 'Event end time in ns since the unix epoch') + ] + __version__ = '0.0.4' + + def setup(self): + # z position of the in-plane SiPMs + z_plane = 10 + # radius of the cylinder for SiPMs at the side + r_cylinder = 22 + # radius of a SiPM - I assume circular SiPMs with a radius to make the area correspond to a 3x3mm2 square. 
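+        # (equal area: pi * r^2 = 9 mm^2  =>  r = 3 / sqrt(pi) ~= 1.6925 mm)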
+ r_sipm = 1.6925 + # build geometry + geo = GeoParameters(z_plane=z_plane, r_cylinder=r_cylinder, r_sipm=r_sipm) + + sipm = SiPM(type="plane", position=[0, -15, z_plane], qeff=0.25) + geo.add_sipm(sipm) + sipm = SiPM(type="plane", position=[-13, -7.5, z_plane], qeff=0.25) + geo.add_sipm(sipm) + sipm = SiPM(type="plane", position=[13, -7.5, z_plane], qeff=0.25) + geo.add_sipm(sipm) + sipm = SiPM(type="plane", position=[-4, 0, z_plane], qeff=0.25) + geo.add_sipm(sipm) + sipm = SiPM(type="plane", position=[4, 0, z_plane], qeff=0.25) + geo.add_sipm(sipm) + sipm = SiPM(type="plane", position=[-13, 7.5, z_plane], qeff=0.25) + geo.add_sipm(sipm) + sipm = SiPM(type="plane", position=[13, 7.5, z_plane], qeff=0.25) + geo.add_sipm(sipm) + + self.geo = geo + + def compute_loop(self, events, peaks): + result = dict() + + if not len(peaks): + return result + + s2_index = events['s2_index'] + if s2_index == -1 or s2_index > len(peaks[(peaks['type'] == 2)]) - 1: + return result + + s2_peak = peaks[(peaks['type'] == 2)][s2_index] + for i, area in enumerate(s2_peak['area_per_channel'][:7]): + self.geo.sipms[i].set_number_of_hits(area) + + posrec = Reconstruction(self.geo) + pos = posrec.reconstruct_position('LNLIKE') + for key in ['xr', 'yr']: + result[key] = pos[key] + return result diff --git a/amstrax/matplotlib_utils.py b/amstrax/matplotlib_utils.py new file mode 100644 index 00000000..d148a407 --- /dev/null +++ b/amstrax/matplotlib_utils.py @@ -0,0 +1,164 @@ +import warnings + +import numpy as np +import matplotlib +import matplotlib.pyplot as plt + +import strax + +export, __all__ = strax.exporter() + + +@export +def plot_pmts( + c, label='', + figsize=None, + xenon1t=False, + show_tpc=True, + extend='neither', vmin=None, vmax=None, + **kwargs): + """Plot the PMT arrays side-by-side, coloring the PMTS with c. + :param c: Array of colors to use. Must have len() n_tpc_pmts + :param label: Label for the color bar + :param figsize: Figure size to use. + :param extend: same as plt.colorbar(extend=...) + :param vmin: Minimum of color scale + :param vmax: maximum of color scale + :param show_axis_labels: if True it will show x and y labels + Other arguments are passed to plot_on_single_pmt_array. 
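+
+    Example (a sketch: ``peaks`` would come from e.g.
+    ``st.get_array(run_id, 'peaks')``):
+        plot_pmts(np.sum(peaks['area_per_channel'], axis=0), label='Area [PE]')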
+ """ + if vmin is None: + vmin = np.nanmin(c) + if vmax is None: + vmax = np.nanmax(c) + if vmin == vmax: + # Single-valued array passed + vmax += 1 + if figsize is None: + figsize = (11.25, 4.25) if xenon1t else (13.25, 5.75) + + f, axes = plt.subplots(1, 2, figsize=figsize) + plot_result = None + for array_i, array_name in enumerate(['top', 'bottom']): + ax = axes[array_i] + plt.sca(ax) + plt.title(array_name.capitalize()) + + plot_result = plot_on_single_pmt_array( + c, + xenon1t=xenon1t, + array_name=array_name, + show_tpc=show_tpc, + vmin=vmin, vmax=vmax, + **kwargs) + + axes[0].set_xlabel('x [cm]') + axes[0].xaxis.set_label_coords(1.035, -0.075) + axes[0].set_ylabel('y [cm]') + + axes[1].yaxis.tick_right() + axes[1].yaxis.set_label_position('right') + + plt.tight_layout() + plt.subplots_adjust(wspace=0) + plt.colorbar(mappable=plot_result, ax=axes, extend=extend, label=label) + + + +@export +def log_y(a=None, b=None, scalar_ticks=True, tick_at=None): + """Make the y axis use a log scale from a to b""" + plt.yscale('log') + if a is not None: + if b is None: + a, b = a[0], a[-1] + ax = plt.gca() + plt.ylim(a, b) + if scalar_ticks: + ax.yaxis.set_major_formatter( + matplotlib.ticker.FormatStrFormatter('%g')) + ax.set_yticks(logticks(a, b, tick_at)) + + +@export +def log_x(a=None, b=None, scalar_ticks=True, tick_at=None): + """Make the x axis use a log scale from a to b""" + plt.xscale('log') + if a is not None: + if b is None: + a, b = a[0], a[-1] + plt.xlim(a, b) + ax = plt.gca() + if scalar_ticks: + ax.xaxis.set_major_formatter( + matplotlib.ticker.FormatStrFormatter('%g')) + ax.set_xticks(logticks(a, b, tick_at)) + + +def logticks(tmin, tmax=None, tick_at=None): + if tick_at is None: + tick_at = (1, 2, 5, 10) + a, b = np.log10([tmin, tmax]) + a = np.floor(a) + b = np.ceil(b) + ticks = np.sort(np.unique(np.outer( + np.array(tick_at), + 10.**np.arange(a, b)).ravel())) + ticks = ticks[(tmin <= ticks) & (ticks <= tmax)] + return ticks + + +@export +def draw_box(x, y, **kwargs): + """Draw rectangle, given x-y boundary tuples""" + plt.gca().add_patch(matplotlib.patches.Rectangle( + (x[0], y[0]), x[1] - x[0], y[1] - y[0], facecolor='none', **kwargs)) + + +@export +def plot_single_pulse(records, run_id, pulse_i=''): + """ + Function which plots a single pulse. + + :param records: Records which belong to the pulse. + :param run_id: Id of the run. + :param pulse_i: Index of the pulse to be plotted. + + :returns: fig, axes objects. + """ + pulse = _make_pulse(records) + + fig, axes = plt.subplots() + sec, ns = _split_off_ns(records[0]['time']) + date = np.datetime_as_string(sec.astype(' warn_beyond_sec: + tr_str = ( + "the entire run" if tr is None else f"{sec_requested} seconds" + ) + raise ValueError( + f"The author of this mini analysis recommends " + f"not requesting more than {warn_beyond_sec} seconds. " + f"You are requesting {tr_str}. If you wish to proceed, " + "pass ignore_time_warning = True." 
+ ) + + # Load required data, if any + if len(requires): + deps_by_kind = strax.group_by_kind(requires, context=context) + for dkind, dtypes in deps_by_kind.items(): + if dkind in kwargs: + print(f"Already have {dkind} data, just apply cuts") + + # Already have data, just apply cuts + kwargs[dkind] = strax.apply_selection( + kwargs[dkind], + selection_str=kwargs["selection_str"], + time_range=kwargs["time_range"], + time_selection=kwargs["time_selection"], + ) + else: + print(f"Loading {dkind} data: {dtypes}") + + kwargs[dkind] = context.get_array( + run_id, + dtypes, + selection_str=kwargs["selection_str"], + time_range=kwargs["time_range"], + time_selection=kwargs["time_selection"], + # Arguments for new context, if needed + config=kwargs.get("config"), + register=kwargs.get("register"), + storage=kwargs.get("storage", tuple()), + progress_bar=False, + ) + + # If the user did not give time kwargs, but the function expects + # a time_range, try to add one based on the time range of the data + base_dkind = list(deps_by_kind.keys())[0] + x = kwargs[base_dkind] + + if len(x) and kwargs.get("time_range") is None: + x0 = x.iloc[0] if isinstance(x, pd.DataFrame) else x[0] + try: + kwargs.setdefault( + "time_range", (x0["time"], strax.endtime(x).max()) + ) + + except AttributeError: + # If x is a holoviews dataset, this will fail. + pass + + if "seconds_range" in parameters: + if kwargs.get("time_range") is None: + scr = None + else: + scr = tuple( + [ + (t - kwargs["t_reference"]) / int(1e9) + for t in kwargs["time_range"] + ] + ) + kwargs.setdefault("seconds_range", scr) + + kwargs.setdefault("run_id", run_id) + kwargs.setdefault("context", context) + + if "kwargs" in parameters: + # Likely, this will be passed to another mini-analysis + to_pass = kwargs + # Do not pass time_range and seconds_range both (unless explicitly requested) + # strax does not like that + if "seconds_range" in to_pass and "seconds_range" not in parameters: + del to_pass["seconds_range"] + if "time_within" in to_pass and "time_within" not in parameters: + del to_pass["time_within"] + else: + # Pass only arguments the function wants + to_pass = {k: v for k, v in kwargs.items() if k in parameters} + + return f(**to_pass) + + wrapped_f.__name__ = f.__name__ + + if hasattr(f, "__doc__") and f.__doc__: + doc_lines = f.__doc__.splitlines() + wrapped_f.__doc__ = ( + doc_lines[0] + "\n" + textwrap.dedent("\n".join(doc_lines[1:])) + ) + else: + wrapped_f.__doc__ = ( + "Amstrax mini-analysis for which someone was too lazy " + "to write a proper docstring" + ) + + wrapped_f.__doc__ += ma_doc.format(requires=", ".join(requires)) + + strax.Context.add_method(wrapped_f) + return wrapped_f + + return decorator diff --git a/amstrax/plugins/__init__.py b/amstrax/plugins/__init__.py index 8c189f39..82a292f1 100644 --- a/amstrax/plugins/__init__.py +++ b/amstrax/plugins/__init__.py @@ -6,8 +6,8 @@ from . import records from .records import * -from . import led_cal -from .led_cal import * +from . import led_calibration +from .led_calibration import * from . import peaks from .peaks import * diff --git a/amstrax/plugins/events/__init__.py b/amstrax/plugins/events/__init__.py index 1287623a..d96f0b01 100644 --- a/amstrax/plugins/events/__init__.py +++ b/amstrax/plugins/events/__init__.py @@ -4,15 +4,15 @@ from . import event_basics from .event_basics import * -from . import event_positions -from .event_positions import * - from . import corrected_areas from .corrected_areas import * from . 
import energy_estimates from .energy_estimates import * +from . import event_positions +from .event_positions import * + from . import event_info from .event_info import * diff --git a/amstrax/plugins/events/corrected_areas.py b/amstrax/plugins/events/corrected_areas.py index 2406e7b6..2eda864e 100644 --- a/amstrax/plugins/events/corrected_areas.py +++ b/amstrax/plugins/events/corrected_areas.py @@ -3,44 +3,90 @@ import strax export, __all__ = strax.exporter() -# @export -# @strax.takes_config( -# strax.Option( -# 's1_relative_lce_map', -# help="S1 relative LCE(x,y,z) map", -# default_by_run=[ -# (0, pax_file('XENON1T_s1_xyz_lce_true_kr83m_SR0_pax-680_fdc-3d_v0.json')), # noqa -# (first_sr1_run, pax_file('XENON1T_s1_xyz_lce_true_kr83m_SR1_pax-680_fdc-3d_v0.json'))]), # noqa -# strax.Option( -# 's2_relative_lce_map', -# help="S2 relative LCE(x, y) map", -# default_by_run=[ -# (0, pax_file('XENON1T_s2_xy_ly_SR0_24Feb2017.json')), -# (170118_1327, pax_file('XENON1T_s2_xy_ly_SR1_v2.2.json'))]), -# strax.Option( -# 'elife_file', -# default='https://raw.githubusercontent.com/XENONnT/strax_auxiliary_files/master/elife.npy', -# help='link to the electron lifetime')) -# class CorrectedAreas(strax.Plugin): -# depends_on = ['event_basics', 'event_positions'] -# dtype = [('cs1', np.float32, 'Corrected S1 area (PE)'), -# ('cs2', np.float32, 'Corrected S2 area (PE)')] -# -# def setup(self): -# self.s1_map = InterpolatingMap( -# get_resource(self.config['s1_relative_lce_map'])) -# self.s2_map = InterpolatingMap( -# get_resource(self.config['s2_relative_lce_map'])) -# # self.elife = get_elife(self.run_id,self.config['elife_file']) -# self.elife = 632e5 -# -# def compute(self, events): -# event_positions = np.vstack([events['x'], events['y'], events['z']]).T -# s2_positions = np.vstack([events['x_s2'], events['y_s2']]).T -# lifetime_corr = np.exp( -# events['drift_time'] / self.elife) -# -# return dict( -# cs1=events['s1_area'] / self.s1_map(event_positions), -# cs2=events['s2_area'] * lifetime_corr / self.s2_map(s2_positions)) +@export +class CorrectedAreas(strax.Plugin): + """ + Plugin which applies light collection efficiency maps and electron + life time to the data. + Computes the cS1/cS2 for the main/alternative S1/S2 as well as the + corrected life time. + Note: + Please be aware that for both, the main and alternative S1, the + area is corrected according to the xy-position of the main S2. + There are now 3 components of cS2s: cs2_top, cS2_bottom and cs2. + cs2_top and cs2_bottom are corrected by the corresponding maps, + and cs2 is the sum of the two. 
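+    NB (current amstrax status): the light-yield and electron-lifetime
+    correction factors below are placeholders (set to 1) until LCE maps
+    and an elife value are available; see the commented-out factors in
+    compute().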
+ """ + __version__ = '0.5.1' + + depends_on = ['event_basics', 'event_positions'] + + def infer_dtype(self): + dtype = [] + dtype += strax.time_fields + + for peak_type, peak_name in zip(['', 'alt_'], ['main', 'alternate']): + # Only apply + dtype += [ + (f'{peak_type}cs1', np.float32, f'Corrected area of {peak_name} S1 [PE]'), + ( + f'{peak_type}cs1_wo_timecorr', np.float32, + f'Corrected area of {peak_name} S1 (before LY correction) [PE]', + ), + ] + names = ['_wo_timecorr', '_wo_picorr', '_wo_elifecorr', ''] + descriptions = ['S2 xy', 'SEG/EE', 'photon ionization', 'elife'] + for i, name in enumerate(names): + if i == len(names) - 1: + description = '' + elif i == 0: + # special treatment for wo_timecorr, apply elife correction + description = ' (before ' + ' + '.join(descriptions[i + 1:-1]) + description += ', after ' + ' + '.join( + descriptions[:i + 1] + descriptions[-1:]) + ')' + else: + description = ' (before ' + ' + '.join(descriptions[i + 1:]) + description += ', after ' + ' + '.join(descriptions[:i + 1]) + ')' + dtype += [ + ( + f'{peak_type}cs2{name}', np.float32, + f'Corrected area of {peak_name} S2{description} [PE]', + ), + ( + f'{peak_type}cs2_area_fraction_top{name}', np.float32, + f'Fraction of area seen by the top PMT array for corrected ' + f'{peak_name} S2{description}', + ), + ] + return dtype + + def compute(self, events): + result = np.zeros(len(events), self.dtype) + result['time'] = events['time'] + result['endtime'] = events['endtime'] + + # S1 corrections depend on the actual corrected event position. + # We use this also for the alternate S1; for e.g. Kr this is + # fine as the S1 correction varies slowly. + event_positions = np.vstack([events['x'], events['y'], events['z']]).T + + for peak_type in ["", "alt_"]: + result[f"{peak_type}cs1"] = ( + result[f"{peak_type}cs1_wo_timecorr"] / 1) #self.rel_light_yield) + + # now can start doing corrections + for peak_type in ["", "alt_"]: + # S2(x,y) corrections use the observed S2 positions + s2_positions = np.vstack([events[f'{peak_type}s2_x'], events[f'{peak_type}s2_y']]).T + + # collect electron lifetime correction + # for electron lifetime corrections to the S2s, + # use drift time computed using the main S1. 
+ el_string = peak_type + "s2_interaction_" if peak_type == "alt_" else peak_type + elife_correction = 1 #np.exp(events[f'{el_string}drift_time'] / self.elife) + + # apply electron lifetime correction + result[f"{peak_type}cs2"] = events[f"{peak_type}s2_area"] * elife_correction + + return result # \ No newline at end of file diff --git a/amstrax/plugins/events/event_basics.py b/amstrax/plugins/events/event_basics.py index 7445e89b..25008309 100644 --- a/amstrax/plugins/events/event_basics.py +++ b/amstrax/plugins/events/event_basics.py @@ -1,111 +1,377 @@ -import numba -import numpy as np import strax -export, __all__ = strax.exporter() +import numpy as np +import numba +import amstrax -@export -class EventBasics(strax.LoopPlugin): +export, __all__ = strax.exporter() +@strax.takes_config( + strax.Option('electron_drift_velocity', default=1.6e-4, track=True, + help='Vertical electron drift velocity in cm/ns (1e4 m/ms)'), + strax.Option('allow_posts2_s1s', + default=False, infer_type=False, + help="Allow S1s past the main S2 to become the main S1 and S2"), + strax.Option('force_main_before_alt', + default=False, infer_type=False, + help="Make the alternate S1 (and likewise S2) the main S1 if " + "occurs before the main S1."), + strax.Option('force_alt_s2_in_max_drift_time', + default=True, infer_type=False, + help="Make sure alt_s2 is in max drift time starting from main S1"), + strax.Option('event_s1_min_coincidence', + default=2, infer_type=False, + help="Event level S1 min coincidence. Should be >= s1_min_coincidence " + "in the peaklet classification"), + strax.Option('max_drift_length', + default=amstrax.tpc_z, infer_type=False, + help='Total length of the TPC from the bottom of gate to the ' + 'top of cathode wires [cm]'), +) +@export +class EventBasics(strax.Plugin): """ - TODO - """ + Computes the basic properties of the main/alternative S1/S2 within + an event. + The main S1 and alternative S1 are given by the largest two S1-Peaks + within the event. + The main S2 is given by the largest S2-Peak within the event, while + alternative S2 is selected as the largest S2 other than main S2 + in the time window [main S1 time, main S1 time + max drift time]. 
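+
+    Example (a sketch: assumes a configured amstrax context ``st``):
+        events = st.get_array(run_id, 'event_basics')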
+ """ + __version__ = '1.3.3' - rechunk_on_save = False - __version__ = '0.0.21' - # Peak Positions temporarily taken out - # n_competing within peak_basics depends_on = ('events', - 'peak_basics', 'peak_classification',) - - # 'peak_positions') #n_competing - + 'peak_basics', + 'peak_positions', + 'peak_proximity') provides = 'event_basics' - + data_kind = 'events' + loop_over = 'events' def infer_dtype(self): - dtype = [(('Number of peaks in the event', - 'n_peaks'), np.int32), - (('Drift time between main S1 and S2 in ns', - 'drift_time'), np.int64), - ('time', np.int64, 'Event start time in ns since the unix epoch'), - ('endtime', np.int64, 'Event end time in ns since the unix epoch') - ] - for i in [1, 2]: - dtype += [((f'Main S{i} peak index', - f's{i}_index'), np.int32), - ((f'Main S{i} area (PE), uncorrected', - f's{i}_area'), np.float32), - ((f'Main S{i} area fraction top', - f's{i}_area_fraction_top'), np.float32), - ((f'Main S{i} width (ns, 50% area)', - f's{i}_range_50p_area'), np.float32), - ((f'Main S{i} number of competing peaks', - f's{i}_n_competing'), np.int32)] - # dtype += [(f'x_s2', np.float32, - # f'Main S2 reconstructed X position (cm), uncorrected',), - # (f'y_s2', np.float32, - # f'Main S2 reconstructed Y position (cm), uncorrected',)] - dtype += [(f's2_largest_other', np.float32, - f'Largest other S2 area (PE) in event, uncorrected',), - (f's1_largest_other', np.float32, - f'Largest other S1 area (PE) in event, uncorrected',), - (f'alt_s1_interaction_drift_time', np.float32, - f'Drift time with alternative s1',) + # Basic event properties + self._set_posrec_save() + self._set_dtype_requirements() + dtype = [] + dtype += strax.time_fields + dtype += [('n_peaks', np.int32, + 'Number of peaks in the event'), + ('drift_time', np.float32, + 'Drift time between main S1 and S2 in ns'), + ('event_number', np.int64, + 'Event number in this dataset'), ] + dtype += self._get_si_dtypes(self.peak_properties) + + dtype += [ + (f's1_max_diff', np.int32, + f'Main S1 largest time difference between apexes of hits [ns]'), + (f'alt_s1_max_diff', np.int32, + f'Alternate S1 largest time difference between apexes of hits [ns]'), + (f's1_min_diff', np.int32, + f'Main S1 smallest time difference between apexes of hits [ns]'), + (f'alt_s1_min_diff', np.int32, + f'Alternate S1 smallest time difference between apexes of hits [ns]'), + ] + + dtype += [ + (f's2_x', np.float32, + f'Main S2 reconstructed X position, uncorrected [cm]'), + (f's2_y', np.float32, + f'Main S2 reconstructed Y position, uncorrected [cm]'), + (f'alt_s2_x', np.float32, + f'Alternate S2 reconstructed X position, uncorrected [cm]'), + (f'alt_s2_y', np.float32, + f'Alternate S2 reconstructed Y position, uncorrected [cm]'), + (f'area_before_main_s2', np.float32, + f'Sum of areas before Main S2 [PE]'), + (f'large_s2_before_main_s2', np.float32, + f'The largest S2 before the Main S2 [PE]') + ] + + dtype += self._get_posrec_dtypes() return dtype - def compute_loop(self, event, peaks): - result = dict(n_peaks=len(peaks)) - if not len(peaks): - return result - - main_s = dict() - for s_i in [2, 1]: - s_mask = peaks['type'] == s_i - - # For determining the main S1, remove all peaks - # after the main S2 (if there was one) - # This is why S2 finding happened first - if s_i == 1 and result[f's2_index'] != -1: - s_mask &= peaks['time'] < main_s[2]['time'] - - ss = peaks[s_mask] - s_indices = np.arange(len(peaks))[s_mask] - - if not len(ss): - result[f's{s_i}_index'] = -1 - continue - - main_i = np.argmax(ss['area']) - # Find largest 
other signals - if s_i == 2 and ss['n_competing'][main_i] > 0 and len(ss['area']) > 1: - s2_second_i = np.argsort(ss['area'])[-2] - result[f's2_largest_other'] = ss['area'][s2_second_i] - - if s_i == 1 and ss['n_competing'][main_i] > 0 and len(ss['area']) > 1: - s1_second_i = np.argsort(ss['area'])[-2] - result[f's1_largest_other'] = ss['area'][s1_second_i] - - result[f's{s_i}_index'] = s_indices[main_i] - s = main_s[s_i] = ss[main_i] - - for prop in ['area', 'area_fraction_top', - 'range_50p_area', 'n_competing']: - result[f's{s_i}_{prop}'] = s[prop] - # if s_i == 2: - # result['x_s2'] = s['xr'] - # result['y_s2'] = s['yr'] - - # Compute a drift time only if we have a valid S1-S2 pairs - if len(main_s) == 2: - result['drift_time'] = main_s[2]['time'] - main_s[1]['time'] - # Compute alternative drift time - if 's1_second_i' in locals(): - result['alt_s1_interaction_drift_time'] = main_s[2]['time'] - ss['time'][ - s1_second_i] + def _set_dtype_requirements(self): + """Needs to be run before inferring dtype as it is needed there""" + # Properties to store for each peak (main and alternate S1 and S2) + self.peak_properties = ( + ('time', np.int64, 'start time since unix epoch [ns]'), + ('center_time', np.int64, 'weighted center time since unix epoch [ns]'), + ('endtime', np.int64, 'end time since unix epoch [ns]'), + ('area', np.float32, 'area, uncorrected [PE]'), + ('n_channels', np.int16, 'count of contributing PMTs'), + ('n_hits', np.int16, 'count of hits contributing at least one sample to the peak'), + ('n_competing', np.int32, 'number of competing peaks'), + ('max_pmt', np.int16, 'PMT number which contributes the most PE'), + ('max_pmt_area', np.float32, 'area in the largest-contributing PMT (PE)'), + ('range_50p_area', np.float32, 'width, 50% area [ns]'), + ('range_90p_area', np.float32, 'width, 90% area [ns]'), + ('rise_time', np.float32, 'time between 10% and 50% area quantiles [ns]'), + ('area_fraction_top', np.float32, 'fraction of area seen by the top PMT array'), + ('tight_coincidence', np.int16, 'Channel within tight range of mean'), + ('n_saturated_channels', np.int16, 'Total number of saturated channels'), + ) + + def setup(self): + + self.electron_drift_velocity = self.config['electron_drift_velocity'] + self.allow_posts2_s1s = self.config['allow_posts2_s1s'] + self.force_main_before_alt = self.config['force_main_before_alt'] + self.force_alt_s2_in_max_drift_time = self.config['force_alt_s2_in_max_drift_time'] + self.event_s1_min_coincidence = self.config['event_s1_min_coincidence'] + self.max_drift_length = self.config['max_drift_length'] + + self.drift_time_max = int(self.max_drift_length / self.electron_drift_velocity) + + + + @staticmethod + def _get_si_dtypes(peak_properties): + """Get properties for S1/S2 from peaks directly""" + si_dtype = [] + for s_i in [1, 2]: + # Peak indices + si_dtype += [ + (f's{s_i}_index', np.int32, f'Main S{s_i} peak index in event'), + (f'alt_s{s_i}_index', np.int32, f'Alternate S{s_i} peak index in event')] + + # Peak properties + for name, dt, comment in peak_properties: + si_dtype += [(f's{s_i}_{name}', dt, f'Main S{s_i} {comment}'), + (f'alt_s{s_i}_{name}', dt, f'Alternate S{s_i} {comment}')] + + # Drifts and delays + si_dtype += [ + (f'alt_s{s_i}_interaction_drift_time', np.float32, + f'Drift time using alternate S{s_i} [ns]'), + (f'alt_s{s_i}_delay', np.int32, + f'Time between main and alternate S{s_i} [ns]')] + return si_dtype + + def _set_posrec_save(self): + """ + parse x_mlp et cetera if needed to get the algorithms used and + set 
required class attributes
+        """
+        posrec_fields = self.deps['peak_positions'].dtype_for('peak_positions').names
+        posrec_names = [d.split('_')[-1] for d in posrec_fields if 'x_' in d]
+
+        # Preserve order. "set" is not ordered and dtypes should always be ordered
+        self.pos_rec_labels = list(set(posrec_names))
+        self.pos_rec_labels.sort()
+
+        self.posrec_save = [(xy + algo)
+                            for xy in ['x_', 'y_']
+                            for algo in self.pos_rec_labels]
+
+    def _get_posrec_dtypes(self):
+        """Get S2 positions for each of the position reconstruction algorithms"""
+        posrec_dtype = []
+
+        for algo in self.pos_rec_labels:
+            # S2 positions
+            posrec_dtype += [
+                (f's2_x_{algo}', np.float32,
+                 f'Main S2 {algo}-reconstructed X position, uncorrected [cm]'),
+                (f's2_y_{algo}', np.float32,
+                 f'Main S2 {algo}-reconstructed Y position, uncorrected [cm]'),
+                (f'alt_s2_x_{algo}', np.float32,
+                 f'Alternate S2 {algo}-reconstructed X position, uncorrected [cm]'),
+                (f'alt_s2_y_{algo}', np.float32,
+                 f'Alternate S2 {algo}-reconstructed Y position, uncorrected [cm]')]
+        return posrec_dtype
+
+    @staticmethod
+    def set_nan_defaults(buffer):
+        """
+        When constructing the dtype, take extra care to set values to
+        np.nan / -1 (for ints), as 0 might have a meaning
+        """
+        for field in buffer.dtype.names:
+            if np.issubdtype(buffer.dtype[field], np.integer):
+                buffer[field][:] = -1
+            else:
+                buffer[field][:] = np.nan
+
+    def compute(self, events, peaks):
+        result = np.zeros(len(events), dtype=self.dtype)
+        self.set_nan_defaults(result)
+
+        split_peaks = strax.split_by_containment(peaks, events)
+
+        result['time'] = events['time']
+        result['endtime'] = events['endtime']
+        result['event_number'] = events['event_number']
+
+        self.fill_events(result, events, split_peaks)
         return result

+    # If copy_largest_peaks_into_event is ever numbafied, also numbafy this function
+    def fill_events(self, result_buffer, events, split_peaks):
+        """Loop over the events and peaks within that event"""
+        for event_i, _ in enumerate(events):
+            peaks_in_event_i = split_peaks[event_i]
+            n_peaks = len(peaks_in_event_i)
+            result_buffer[event_i]['n_peaks'] = n_peaks
+
+            if not n_peaks:
+                raise ValueError(f'No peaks within event?\n{events[event_i]}')
+
+            self.fill_result_i(result_buffer[event_i], peaks_in_event_i)
+
+    def fill_result_i(self, event, peaks):
+        """Fill the result buffer for a single event"""
+        # Consider S2s first, then S1s (to enable allow_posts2_s1s = False)
+        # number_of_peaks=0 selects all available S2s, sorted by area
+        largest_s2s, s2_idx = self.get_largest_sx_peaks(peaks, s_i=2, number_of_peaks=0)
+
+        if not self.allow_posts2_s1s and len(largest_s2s):
+            s1_latest_time = largest_s2s[0]['time']
+        else:
+            s1_latest_time = np.inf
+
+        largest_s1s, s1_idx = self.get_largest_sx_peaks(
+            peaks,
+            s_i=1,
+            s1_before_time=s1_latest_time,
+            s1_min_coincidence=self.event_s1_min_coincidence)
+
+        if self.force_alt_s2_in_max_drift_time:
+            s2_idx, largest_s2s = self.find_main_alt_s2(largest_s1s,
+                                                        s2_idx,
+                                                        largest_s2s,
+                                                        self.drift_time_max,
+                                                        )
+        else:
+            # Select only the largest two S2s
+            largest_s2s, s2_idx = largest_s2s[0:2], s2_idx[0:2]
+
+        if self.force_main_before_alt:
+            s2_order = np.argsort(largest_s2s['time'])
+            largest_s2s = largest_s2s[s2_order]
+            s2_idx = s2_idx[s2_order]
+
+        self.set_sx_index(event, s1_idx, s2_idx)
+        self.set_event_properties(event, largest_s1s, largest_s2s, peaks)
+
+        # Loop over S1s and S2s and over main / alt.
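+        # For example, the 'area' of the main S2 is written to 's2_area',
+        # and that of the alternate S1 to 'alt_s1_area'.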
+        for s_i, largest_s_i in enumerate([largest_s1s, largest_s2s], 1):
+            # Largest index 0 -> main sx, 1 -> alt sx
+            for largest_index, main_or_alt in enumerate(['s', 'alt_s']):
+                peak_properties_to_save = [name for name, _, _ in self.peak_properties]
+                if s_i == 1:
+                    peak_properties_to_save += ['max_diff', 'min_diff']
+                elif s_i == 2:
+                    peak_properties_to_save += ['x', 'y']
+                    peak_properties_to_save += self.posrec_save
+                field_names = [f'{main_or_alt}{s_i}_{name}' for name in peak_properties_to_save]
+                self.copy_largest_peaks_into_event(event,
+                                                   largest_s_i,
+                                                   largest_index,
+                                                   field_names,
+                                                   peak_properties_to_save)
+
+    @staticmethod
+    @numba.njit
+    def find_main_alt_s2(largest_s1s, s2_idx, largest_s2s, drift_time_max):
+        """Require that alt_s2 occurs between the main S1 and the maximum drift time"""
+        if len(largest_s1s) > 0 and len(largest_s2s) > 1:
+            # If there is a valid S1-S2 pair and a second S2, check the alt S2 validity
+            s2_after_s1 = largest_s2s['center_time'] > largest_s1s[0]['center_time']
+            s2_before_max_drift_time = (largest_s2s['center_time']
+                                        - largest_s1s[0]['center_time']) < 1.01 * drift_time_max
+            mask = s2_after_s1 & s2_before_max_drift_time
+            # Always keep the main S2, regardless of the selection
+            mask[0] = True
+            # Take main and the largest valid alt_S2
+            s2_idx, largest_s2s = s2_idx[mask], largest_s2s[mask]
+        return s2_idx[:2], largest_s2s[:2]
+
+    @staticmethod
+    @numba.njit
+    def set_event_properties(result, largest_s1s, largest_s2s, peaks):
+        """Get properties like drift time and area before main S2"""
+        # Compute drift times only if we have a valid S1-S2 pair
+        if len(largest_s1s) > 0 and len(largest_s2s) > 0:
+            result['drift_time'] = largest_s2s[0]['center_time'] - largest_s1s[0]['center_time']
+            if len(largest_s1s) > 1:
+                result['alt_s1_interaction_drift_time'] = largest_s2s[0]['center_time'] - largest_s1s[1]['center_time']
+                result['alt_s1_delay'] = largest_s1s[1]['center_time'] - largest_s1s[0]['center_time']
+            if len(largest_s2s) > 1:
+                result['alt_s2_interaction_drift_time'] = largest_s2s[1]['center_time'] - largest_s1s[0]['center_time']
+                result['alt_s2_delay'] = largest_s2s[1]['center_time'] - largest_s2s[0]['center_time']
+
+        # areas before main S2
+        if len(largest_s2s):
+            peaks_before_ms2 = peaks[peaks['time'] < largest_s2s[0]['time']]
+            result['area_before_main_s2'] = np.sum(peaks_before_ms2['area'])
+
+            s2peaks_before_ms2 = peaks_before_ms2[peaks_before_ms2['type'] == 2]
+            if len(s2peaks_before_ms2) == 0:
+                result['large_s2_before_main_s2'] = 0
+            else:
+                result['large_s2_before_main_s2'] = np.max(s2peaks_before_ms2['area'])
+        return result
+
+    @staticmethod
+    # @numba.njit <- works but slows if fill_events is not numbafied
+    def get_largest_sx_peaks(peaks,
+                             s_i,
+                             s1_before_time=np.inf,
+                             s1_min_coincidence=0,
+                             number_of_peaks=2):
+        """Get the largest S1/S2. For S1s allow a min coincidence and max time"""
+        # Find all peaks of this type (S1 or S2)
+        s_mask = peaks['type'] == s_i
+        if s_i == 1:
+            s_mask &= peaks['time'] < s1_before_time
+            s_mask &= peaks['tight_coincidence'] >= s1_min_coincidence
+
+        selected_peaks = peaks[s_mask]
+        s_index = np.arange(len(peaks))[s_mask]
+        largest_peaks = np.argsort(selected_peaks['area'])[-number_of_peaks:][::-1]
+        return selected_peaks[largest_peaks], s_index[largest_peaks]
+
+    # If only we could numbafy this... Unfortunately, we cannot.
+    # Perhaps we could one day consider doing something like strax.copy_to_buffer
+    @staticmethod
+    def copy_largest_peaks_into_event(result,
+                                      largest_s_i,
+                                      main_or_alt_index,
+                                      result_fields,
+                                      peak_fields,
+                                      ):
+        """
+        For one event, write all the peak_fields (e.g. "area") of the peak
+        (largest_s_i) into their associated field in the event (e.g. s1_area);
+        main_or_alt_index differentiates between main (index 0) and alt (index 1)
+        """
+        index_not_in_list_of_largest_peaks = main_or_alt_index >= len(largest_s_i)
+        if index_not_in_list_of_largest_peaks:
+            # There is no such peak. E.g. main_or_alt_index == 1 but largest_s_i = ["Main S1"]
+            # Asking for index 1 doesn't work on a length-1 list of peaks.
+            return
+
+        for i, ev_field in enumerate(result_fields):
+            p_field = peak_fields[i]
+            if p_field not in ev_field:
+                raise ValueError("Event fields must derive from the peak fields")
+            result[ev_field] = largest_s_i[main_or_alt_index][p_field]
+
+    @staticmethod
+    # @numba.njit <- works but slows if fill_events is not numbafied
+    def set_sx_index(res, s1_idx, s2_idx):
+        if len(s1_idx):
+            res['s1_index'] = s1_idx[0]
+            if len(s1_idx) > 1:
+                res['alt_s1_index'] = s1_idx[1]
+        if len(s2_idx):
+            res['s2_index'] = s2_idx[0]
+            if len(s2_idx) > 1:
+                res['alt_s2_index'] = s2_idx[1]
+
diff --git a/amstrax/plugins/events/event_info.py b/amstrax/plugins/events/event_info.py
index f3884e20..c9a1ba8a 100644
--- a/amstrax/plugins/events/event_info.py
+++ b/amstrax/plugins/events/event_info.py
@@ -5,8 +5,8 @@
 @export
 class EventInfo(strax.MergeOnlyPlugin):
-    depends_on = ['events',
-                  'event_basics',
+    depends_on = ['event_basics',
+                  'corrected_areas',
                   'event_positions',
                   # 'energy_estimates',
                   ]
diff --git a/amstrax/plugins/events/event_positions.py b/amstrax/plugins/events/event_positions.py
index a17a5401..08a21d8f 100644
--- a/amstrax/plugins/events/event_positions.py
+++ b/amstrax/plugins/events/event_positions.py
@@ -1,66 +1,140 @@
-import numba
 import numpy as np
+import amstrax
+
+
+DEFAULT_POSREC_ALGO = 'lpf'
+
 import strax
+
 export, __all__ = strax.exporter()
+
 @export
-class EventPositions(strax.LoopPlugin):
-    depends_on = ('events', 'event_basics', 'peaks', 'peak_classification')
-    rechunk_on_save = False
-    dtype = [
-        ('xr', np.float32,
-         'Interaction x-position'),
-        ('yr', np.float32,
-         'Interaction y-position'),
-        ('time', np.int64, 'Event start time in ns since the unix epoch'),
-        ('endtime', np.int64, 'Event end time in ns since the unix epoch')
-    ]
-    __version__ = '0.0.4'
+@strax.takes_config(
+    strax.Option('electron_drift_velocity',
+                 default=0.00016,
+                 help='Vertical electron drift velocity in cm/ns (1e4 m/ms)'),
+    strax.Option('electron_drift_time_gate',
+                 default=1,
+                 help='Electron drift time from the gate in ns'),
+    strax.Option('default_reconstruction_algorithm',
+                 default=DEFAULT_POSREC_ALGO,
+                 help="default reconstruction algorithm that provides (x,y)"),
+)
+
+
+class EventPositions(strax.Plugin):
+    """
+    Computes the observed and corrected position for the main S1/S2
+    pairs in an event, based on the default_reconstruction_algorithm
+    (adapted from the XENONnT FDC scheme). In case an fdc_map is ever
+    given as a file (not through CMT), the coordinate system should be
+    given as (x, y, z), not (x, y, drift_time).
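+    Note that in amstrax the field-distortion (delta_r) and z-bias
+    (z_dv_delta) corrections are currently placeholders set to zero in
+    compute, so the corrected positions equal the naive ones.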
+ """ + + depends_on = ('event_basics',) + + __version__ = '0.3.0' + + + def infer_dtype(self): + dtype = [] + for j in 'x y r'.split(): + comment = f'Main interaction {j}-position, field-distortion corrected (cm)' + dtype += [(j, np.float32, comment)] + for s_i in [1, 2]: + comment = f'Alternative S{s_i} interaction (rel. main S{3 - s_i}) {j}-position, field-distortion corrected (cm)' + field = f'alt_s{s_i}_{j}_fdc' + dtype += [(field, np.float32, comment)] + + for j in ['z']: + comment = 'Interaction z-position, corrected to non-uniform drift velocity (cm)' + dtype += [(j, np.float32, comment)] + comment = 'Interaction z-position, corrected to non-uniform drift velocity, duplicated (cm)' + dtype += [(j + "_dv_corr", np.float32, comment)] + for s_i in [1, 2]: + comment = f'Alternative S{s_i} z-position (rel. main S{3 - s_i}), corrected to non-uniform drift velocity (cm)' + field = f'alt_s{s_i}_z' + dtype += [(field, np.float32, comment)] + # values for corrected Z position + comment = f'Alternative S{s_i} z-position (rel. main S{3 - s_i}), corrected to non-uniform drift velocity, duplicated (cm)' + field = f'alt_s{s_i}_z_dv_corr' + dtype += [(field, np.float32, comment)] + + + naive_pos = [] + fdc_pos = [] + for j in 'r z'.split(): + naive_pos += [(f'{j}_naive', + np.float32, + f'Main interaction {j}-position with observed position (cm)')] + fdc_pos += [(f'{j}_field_distortion_correction', + np.float32, + f'Correction added to {j}_naive for field distortion (cm)')] + for s_i in [1, 2]: + naive_pos += [( + f'alt_s{s_i}_{j}_naive', + np.float32, + f'Alternative S{s_i} interaction (rel. main S{3 - s_i}) {j}-position with observed position (cm)')] + fdc_pos += [(f'alt_s{s_i}_{j}_field_distortion_correction', + np.float32, + f'Correction added to alt_s{s_i}_{j}_naive for field distortion (cm)')] + dtype += naive_pos + fdc_pos + for s_i in [1, 2]: + dtype += [(f'alt_s{s_i}_theta', + np.float32, + f'Alternative S{s_i} (rel. main S{3 - s_i}) interaction angular position (radians)')] + + dtype += [('theta', np.float32, f'Main interaction angular position (radians)')] + return dtype + strax.time_fields def setup(self): - # z position of the in-plane SiPMs - z_plane = 10 - # radius of the cyinder for SiPMs at the side - r_cylinder = 22 - # radius of a SiPM - I assume circular SiPMs with a radius to make the area correspond to a 3x3mm2 square. 
- r_sipm = 1.6925 - # build geometry - geo = GeoParameters(z_plane=z_plane, r_cylinder=r_cylinder, r_sipm=r_sipm) - - sipm = SiPM(type="plane", position=[0, -15, z_plane], qeff=0.25) - geo.add_sipm(sipm) - sipm = SiPM(type="plane", position=[-13, -7.5, z_plane], qeff=0.25) - geo.add_sipm(sipm) - sipm = SiPM(type="plane", position=[13, -7.5, z_plane], qeff=0.25) - geo.add_sipm(sipm) - sipm = SiPM(type="plane", position=[-4, 0, z_plane], qeff=0.25) - geo.add_sipm(sipm) - sipm = SiPM(type="plane", position=[4, 0, z_plane], qeff=0.25) - geo.add_sipm(sipm) - sipm = SiPM(type="plane", position=[-13, 7.5, z_plane], qeff=0.25) - geo.add_sipm(sipm) - sipm = SiPM(type="plane", position=[13, 7.5, z_plane], qeff=0.25) - geo.add_sipm(sipm) - - self.geo = geo - - def compute_loop(self, events, peaks): - result = dict() - - if not len(peaks): - return result - - s2_index = events['s2_index'] - if s2_index == -1 or s2_index > len(peaks[(peaks['type'] == 2)]) - 1: - return result - - s2_peak = peaks[(peaks['type'] == 2)][s2_index] - for i, area in enumerate(s2_peak['area_per_channel'][:7]): - self.geo.sipms[i].set_number_of_hits(area) - - posrec = Reconstruction(self.geo) - pos = posrec.reconstruct_position('LNLIKE') - for key in ['xr', 'yr']: - result[key] = pos[key] - return result + self.electron_drift_velocity = self.config['electron_drift_velocity'] + self.electron_drift_time_gate = self.config['electron_drift_time_gate'] + self.default_reconstruction_algorithm = self.config['default_reconstruction_algorithm'] + + self.coordinate_scales = [1., 1., - self.electron_drift_velocity] + # self.map = self.fdc_map + + def compute(self, events): + + result = {'time': events['time'], + 'endtime': strax.endtime(events)} + + # s_i == 0 indicates the main event, while s_i != 0 means alternative S1 or S2 is used based on s_i value + # while the other peak is the main one (e.g., s_i == 1 means that the event is defined using altS1 and main S2) + for s_i in [0, 1, 2]: + # alt_sx_interaction_drift_time is calculated between main Sy and alternative Sx + drift_time = events['drift_time'] if not s_i else events[f'alt_s{s_i}_interaction_drift_time'] + z_obs = - self.electron_drift_velocity * drift_time + xy_pos = 's2_' if s_i != 2 else 'alt_s2_' + orig_pos = np.vstack([events[f'{xy_pos}x'], events[f'{xy_pos}y'], z_obs]).T + r_obs = np.linalg.norm(orig_pos[:, :2], axis=1) + z_obs = z_obs + self.electron_drift_velocity * self.electron_drift_time_gate + + # apply z bias correction + z_dv_delta = 0 # self.z_bias_map(np.array([r_obs, z_obs]).T, map_name='z_bias_map') + corr_pos = np.vstack([events[f"{xy_pos}x"], events[f"{xy_pos}y"], z_obs - z_dv_delta]).T + # apply r bias correction + delta_r = 0 # self.map(corr_pos) + with np.errstate(invalid='ignore', divide='ignore'): + r_cor = r_obs + delta_r + scale = np.divide(r_cor, r_obs, out=np.zeros_like(r_cor), where=r_obs != 0) + + pre_field = '' if s_i == 0 else f'alt_s{s_i}_' + post_field = '' if s_i == 0 else '_fdc' + result.update({f'{pre_field}x{post_field}': orig_pos[:, 0] * scale, + f'{pre_field}y{post_field}': orig_pos[:, 1] * scale, + f'{pre_field}r{post_field}': r_cor, + f'{pre_field}r_naive': r_obs, + f'{pre_field}r_field_distortion_correction': delta_r, + f'{pre_field}theta': np.arctan2(orig_pos[:, 1], orig_pos[:, 0]), + f'{pre_field}z_naive': z_obs, + # using z_dv_corr (z_obs - z_dv_delta) in agreement with the dtype description + # the FDC for z (z_cor) is found to be not reliable (see #527) + f'{pre_field}z': z_obs - z_dv_delta, + f'{pre_field}z_dv_corr': z_obs - 
z_dv_delta, + }) + return result diff --git a/amstrax/plugins/events/events.py b/amstrax/plugins/events/events.py index 8aef7643..47a20213 100644 --- a/amstrax/plugins/events/events.py +++ b/amstrax/plugins/events/events.py @@ -19,7 +19,9 @@ 'triggering peak'), ) class Events(strax.OverlapWindowPlugin): - depends_on = ['peaks', 'peak_basics'] # peak_basics instead of n_competing + depends_on = ['peaks', + 'peak_basics', + 'peak_proximity'] # peak_basics instead of n_competing rechunk_on_save = False data_kind = 'events' parallel = False diff --git a/amstrax/plugins/led_cal/led_calibration.py b/amstrax/plugins/led_cal/led_calibration.py deleted file mode 100644 index 76d401c7..00000000 --- a/amstrax/plugins/led_cal/led_calibration.py +++ /dev/null @@ -1,147 +0,0 @@ -from immutabledict import immutabledict -import strax -import numba -import numpy as np - -export, __all__ = strax.exporter() - -@export -@strax.takes_config( - strax.Option( - 'led_window', - default=(80, 110), - help="Window (samples) where we expect the signal in LED calibration"), - strax.Option( - 'noise_window', - default=(0, 10), - help="Window (samples) to analysis the noise"), -) -class LEDCalibration(strax.Plugin): - """ - Preliminary version, several parameters to set during commissioning. - LEDCalibration returns: channel, time, dt, length, Area, - amplitudeLED and amplitudeNOISE. - The new variables are: - - Area: Area computed in the given window, averaged over 6 - windows that have the same starting sample and different end - samples. - - amplitudeLED: peak amplitude of the LED on run in the given - window. - - amplitudeNOISE: amplitude of the LED on run in a window far - from the signal one. - """ - - __version__ = '1.0.0' - - depends_on = ('records_led',) - data_kind = 'led_cal' - compressor = 'zstd' - parallel = 'process' - rechunk_on_save = False - - dtype = [('area', np.float32, 'Area averaged in integration windows'), - ('amplitude_led', np.float32, 'Amplitude in LED window'), - ('amplitude_noise', np.float32, 'Amplitude in off LED window'), - ('channel', np.int16, 'Channel'), - ('time', np.int64, 'Start time of the interval (ns since unix epoch)'), - ('dt', np.int16, 'Time resolution in ns'), - ('length', np.int32, 'Length of the interval in samples')] - - - def setup(self): - - self.led_window = self.config['led_window'] - self.noise_window = self.config['noise_window'] - - def compute(self, records_led): - ''' - The data for LED calibration are build for those PMT which belongs to channel list. - This is used for the different ligh levels. As defaul value all the PMTs are considered. - ''' - - r = records_led - - temp = np.zeros(len(r), dtype=self.dtype) - strax.copy_to_buffer(r, temp, "_recs_to_temp_led") - - on, off = get_amplitude(r, self.led_window, self.noise_window) - temp['amplitude_led'] = on['amplitude'] - temp['amplitude_noise'] = off['amplitude'] - - area = get_area(r, self.led_window) - temp['area'] = area['area'] - return temp - - -# def get_records(raw_records, baseline_window, record_i_signal): -# """ -# Determine baseline as the average of the first baseline_samples -# of each pulse. Subtract the pulse float(data) from baseline. 
-# """ - -# record_length = np.shape(raw_records.dtype['data'])[0] - -# _dtype = [(('Start time since unix epoch [ns]', 'time'), ' 0).sum(axis=1) + r['n_hits'] = p['n_hits'] r['range_50p_area'] = p['width'][:, 5] + r['range_90p_area'] = p['width'][:, 9] r['max_pmt'] = np.argmax(p['area_per_channel'], axis=1) r['max_pmt_area'] = np.max(p['area_per_channel'], axis=1) + r['tight_coincidence'] = p['tight_coincidence'] + r['n_saturated_channels'] = p['n_saturated_channels'] - # area_top = p['area_per_channel'][:, :8].sum(axis=1) - area_top = p['area_per_channel'][:, 1:2].sum(axis=1) # top pmt in ch 1 - # Negative-area peaks get 0 AFT - TODO why not NaN? + n_top = self.config["n_top_pmts"] + area_top = p['area_per_channel'][:, :n_top].sum(axis=1) + # Recalculate to prevent numerical inaccuracy #442 + area_total = p['area_per_channel'].sum(axis=1) + # Negative-area peaks get NaN AFT m = p['area'] > 0 - r['area_fraction_top'][m] = area_top[m] / p['area'][m] - # n_competing temporarily due to chunking issues - r['n_competing'] = self.find_n_competing( - peaks, - window=self.config['nearby_window'], - fraction=self.config['min_area_fraction']) + r['area_fraction_top'][m] = area_top[m] / area_total[m] + r['area_fraction_top'][~m] = float('nan') + r['rise_time'] = -p['area_decile_from_midpoint'][:, 1] + + if self.config['check_peak_sum_area_rtol'] is not None: + self.check_area(area_total, p, self.config['check_peak_sum_area_rtol']) + # Negative or zero-area peaks have centertime at startime + r['center_time'] = p['time'] + r['center_time'][m] += self.compute_center_times(peaks[m]) return r # n_competing def get_window_size(self): - return 2 * self.config['nearby_window'] + return 2 * self.config["nearby_window"] @staticmethod @numba.jit(nopython=True, nogil=True, cache=False) def find_n_competing(peaks, window, fraction): n = len(peaks) - t = peaks['time'] - a = peaks['area'] + t = peaks["time"] + a = peaks["area"] results = np.zeros(n, dtype=np.int16) left_i = 0 right_i = 0 @@ -88,7 +131,58 @@ def find_n_competing(peaks, window, fraction): left_i += 1 while t[right_i] - window < t[i] and right_i < n - 1: right_i += 1 - results[i] = np.sum(a[left_i:right_i + 1] > a[i] * fraction) + results[i] = np.sum(a[left_i : right_i + 1] > a[i] * fraction) return results + @staticmethod + @numba.njit(cache=True, nogil=True) + def compute_center_times(peaks): + result = np.zeros(len(peaks), dtype=np.int32) + for p_i, p in enumerate(peaks): + t = 0 + for t_i, weight in enumerate(p["data"]): + t += t_i * p["dt"] * weight + result[p_i] = t / p["area"] + return result + + @staticmethod + def check_area(area_per_channel_sum, peaks, rtol) -> None: + """ + Check if the area of the sum-wf is the same as the total area + (if the area of the peak is positively defined). + + :param area_per_channel_sum: the summation of the + peaks['area_per_channel'] which will be checked against the + values of peaks['area']. + :param peaks: array of peaks. + :param rtol: relative tolerance for difference between + area_per_channel_sum and peaks['area']. See np.isclose. 
+ :raises: ValueError if the peak area and the area-per-channel + sum are not sufficiently close + """ + positive_area = peaks["area"] > 0 + if not np.sum(positive_area): + return + + is_close = np.isclose( + area_per_channel_sum[positive_area], + peaks[positive_area]["area"], + rtol=rtol, + ) + + if not is_close.all(): + for peak in peaks[positive_area][~is_close]: + print("bad area") + strax.print_record(peak) + + p_i = np.where(~is_close)[0][0] + peak = peaks[positive_area][p_i] + area_fraction_off = ( + 1 - area_per_channel_sum[positive_area][p_i] / peak["area"] + ) + message = ( + f"Area not calculated correctly, it's " + f'{100 * area_fraction_off} % off, time: {peak["time"]}' + ) + raise ValueError(message) diff --git a/amstrax/plugins/peaks/peak_positions.py b/amstrax/plugins/peaks/peak_positions.py index 11e46f05..56ae715e 100644 --- a/amstrax/plugins/peaks/peak_positions.py +++ b/amstrax/plugins/peaks/peak_positions.py @@ -3,82 +3,28 @@ import strax export, __all__ = strax.exporter() -# from amstrax.SiPMdata import * -# move this to legacy once you have the new peak_positions.py in amstrax -# @export -# class PeakPositions(strax.Plugin): -# depends_on = ('peaks', 'peak_classification') -# rechunk_on_save = False -# __version__ = '0.0.34' # .33 for LNLIKE -# dtype = [ -# ('xr', np.float32, -# 'Interaction x-position'), -# ('yr', np.float32, -# 'Interaction y-position'), -# ('r', np.float32, -# 'radial distance'), -# ('time', np.int64, 'Start time of the peak (ns since unix epoch)'), -# ('endtime', np.int64, 'End time of the peak (ns since unix epoch)') -# ] - -# def setup(self): -# # z position of the in-plane SiPMs -# z_plane = 10 -# # radius of the cylinder for SiPMs at the side -# r_cylinder = 22 -# # radius of a SiPM - I assume circular SiPMs with a radius to make the area correspond to a 3x3mm2 square. 
-# r_sipm = 1.6925
-# # build geometry
-# geo = GeoParameters(z_plane=z_plane, r_cylinder=r_cylinder, r_sipm=r_sipm)
-
-# sipm = SiPM(type="plane", position=[0, -15, z_plane], qeff=0.25)
-# geo.add_sipm(sipm)
-# sipm = SiPM(type="plane", position=[-13, -7.5, z_plane], qeff=0.25)
-# geo.add_sipm(sipm)
-# # sipm = SiPM(type="plane", position=[0, 15, z_plane], qeff=0.25)
-# # geo.add_sipm(sipm)
-# sipm = SiPM(type="plane", position=[13, -7.5, z_plane], qeff=0.25)
-# geo.add_sipm(sipm)
-# sipm = SiPM(type="plane", position=[-4, 0, z_plane], qeff=0.25)
-# geo.add_sipm(sipm)
-# sipm = SiPM(type="plane", position=[4, 0, z_plane], qeff=0.25)
-# geo.add_sipm(sipm)
-# sipm = SiPM(type="plane", position=[-13, 7.5, z_plane], qeff=0.25)
-# geo.add_sipm(sipm)
-# sipm = SiPM(type="plane", position=[13, 7.5, z_plane], qeff=0.25)
-# geo.add_sipm(sipm)
-
-# self.geo = geo
-
-# def compute(self, peaks):
-
-# result = np.empty(len(peaks), dtype=self.dtype)
-
-# if not len(peaks):
-# return result
-
-# for ix, p in enumerate(peaks):
-
-# if p['type'] != 2:
-# continue
-
-# # if [X] channel is not working
-# k = np.delete(p['area_per_channel'], [2])
-# for i, area in enumerate(k):
-# self.geo.sipms[i].set_number_of_hits(area)
-
-# # if all 8 channels are working
-# # for i, area in enumerate(p['area_per_channel']):
-# # self.geo.sipms[i].set_number_of_hits(area)
-
-# posrec = Reconstruction(self.geo)
-# pos = posrec.reconstruct_position('CHI2')
-# for key in ['xr', 'yr']:
-# result[key][ix] = pos[key]
-
-# for q in ['time', 'endtime']:
-# result[q] = p[q]
-
-# result['r'] = (result['xr'] ** 2 + result['yr'] ** 2) ** (1 / 2)
-# return result
+@export
+class PeakPositions(strax.Plugin):
+    depends_on = ('peaks', 'peak_classification')
+    rechunk_on_save = False
+    __version__ = '0.0.34'  # .33 for LNLIKE
+    dtype = [
+        ('x_lpf', np.float32,
+         'Interaction x-position'),
+        ('y_lpf', np.float32,
+         'Interaction y-position'),
+        ('r', np.float32,
+         'radial distance'),
+        ('time', np.int64, 'Start time of the peak (ns since unix epoch)'),
+        ('endtime', np.int64, 'End time of the peak (ns since unix epoch)')
+    ]
+
+    def compute(self, peaks):
+        # Position reconstruction is not implemented yet: fill the position
+        # fields with NaN placeholders instead of the uninitialized memory
+        # that np.empty would leave in x_lpf, y_lpf and r.
+        result = np.zeros(len(peaks), dtype=self.dtype)
+        result['x_lpf'] = np.nan
+        result['y_lpf'] = np.nan
+        result['r'] = np.nan
+        result['time'] = peaks['time']
+        result['endtime'] = peaks['endtime']
+
+        return result
diff --git a/amstrax/plugins/peaks/peak_proximity.py b/amstrax/plugins/peaks/peak_proximity.py
new file mode 100644
index 00000000..8a8a36df
--- /dev/null
+++ b/amstrax/plugins/peaks/peak_proximity.py
@@ -0,0 +1,92 @@
+import numpy as np
+import numba
+import strax
+import amstrax
+
+
+export, __all__ = strax.exporter()
+
+
+@export
+@strax.takes_config(
+    strax.Option('min_area_fraction', default=0.5,
+                 help='The area of competing peaks must be at least '
+                      'this fraction of that of the considered peak'),
+    strax.Option('nearby_window', default=int(1e6),
+                 help='Peaks starting within this time window (on either side) '
+                      'in ns count as nearby.'),
+    strax.Option('peak_max_proximity_time', default=int(1e8),
+                 help='Maximum value for proximity values such as t_to_next_peak [ns]'),
+
+)
+class PeakProximity(strax.OverlapWindowPlugin):
+    """
+    Look for peaks around a peak to determine how many peaks are in
+    proximity (in time) of a peak.
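+    A peak counts as competing when it starts within nearby_window of the
+    considered peak and its area is at least min_area_fraction times the
+    area of the considered peak.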
+ """ + __version__ = '0.4.0' + + depends_on = ('peak_basics',) + dtype = [ + ('n_competing', np.int32, + 'Number of nearby larger or slightly smaller peaks'), + ('n_competing_left', np.int32, + 'Number of larger or slightly smaller peaks left of the main peak'), + ('t_to_prev_peak', np.int64, + 'Time between end of previous peak and start of this peak [ns]'), + ('t_to_next_peak', np.int64, + 'Time between end of this peak and start of next peak [ns]'), + ('t_to_nearest_peak', np.int64, + 'Smaller of t_to_prev_peak and t_to_next_peak [ns]') + ] + strax.time_fields + + + def setup(self): + + self.min_area_fraction = self.config['min_area_fraction'] + self.nearby_window = self.config['nearby_window'] + self.peak_max_proximity_time = self.config['peak_max_proximity_time'] + + + def get_window_size(self): + return self.peak_max_proximity_time + + def compute(self, peaks): + windows = strax.touching_windows(peaks, peaks, + window=self.nearby_window) + n_left, n_tot = self.find_n_competing( + peaks, + windows, + fraction=self.min_area_fraction) + + t_to_prev_peak = ( + np.ones(len(peaks), dtype=np.int64) + * self.peak_max_proximity_time) + t_to_prev_peak[1:] = peaks['time'][1:] - peaks['endtime'][:-1] + + t_to_next_peak = t_to_prev_peak.copy() + t_to_next_peak[:-1] = peaks['time'][1:] - peaks['endtime'][:-1] + + return dict( + time=peaks['time'], + endtime=strax.endtime(peaks), + n_competing=n_tot, + n_competing_left=n_left, + t_to_prev_peak=t_to_prev_peak, + t_to_next_peak=t_to_next_peak, + t_to_nearest_peak=np.minimum(t_to_prev_peak, t_to_next_peak)) + + @staticmethod + @numba.jit(nopython=True, nogil=True, cache=True) + def find_n_competing(peaks, windows, fraction): + n_left = np.zeros(len(peaks), dtype=np.int32) + n_tot = n_left.copy() + areas = peaks['area'] + + for i, peak in enumerate(peaks): + left_i, right_i = windows[i] + threshold = areas[i] * fraction + n_left[i] = np.sum(areas[left_i:i] > threshold) + n_tot[i] = n_left[i] + np.sum(areas[i + 1:right_i] > threshold) + + return n_left, n_tot diff --git a/amstrax/plugins/records/pulse_processing.py b/amstrax/plugins/records/pulse_processing.py index 9edfbd1b..f9b56922 100644 --- a/amstrax/plugins/records/pulse_processing.py +++ b/amstrax/plugins/records/pulse_processing.py @@ -78,7 +78,10 @@ class PulseProcessing(strax.Plugin): provides = ('records', 'pulse_counts') data_kind = {k: k for k in provides} - save_when = strax.SaveWhen.TARGET + + # I think in amstrax we can save everything + # default is ALWAYS + # save_when = strax.SaveWhen.TARGET def infer_dtype(self,): # The record_length is the same for both raw_records_v1724 and raw_records_v1730 diff --git a/requirements.txt b/requirements.txt index c2497e20..c757681d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ numpy numba>=0.50.0 strax>=0.12.0 pymongo<4.0 +multihist>=0.6.3 IPython sshtunnel iminuit diff --git a/tests/test_basics.py b/tests/test_basics.py index a2c89860..ec1aeaa5 100644 --- a/tests/test_basics.py +++ b/tests/test_basics.py @@ -54,13 +54,12 @@ def test_make(self): self.get_test_data() run_id = self.run_id for target, plugin_class in self.st._plugin_class_registry.items(): - if target in ['raw_records', 'records', 'pulse_counts', 'peaks', 'peak_classification', 'peak_basics', 'events', 'event_basics']: - print('>>>>>>>>>>> Making', target) - self.st.make(run_id, target) - data = self.st.get_array(run_id, target) - print(len(data), 'entries in', target) - if plugin_class.save_when >= strax.SaveWhen.TARGET: - assert self.st.is_stored(run_id, 
target) + print('>>>>>>>>>>> Making', target) + self.st.make(run_id, target) + data = self.st.get_array(run_id, target) + print(len(data), 'entries in', target) + if plugin_class.save_when >= strax.SaveWhen.TARGET: + assert self.st.is_stored(run_id, target) with self.assertRaises(ValueError): # Now since we have the 'raw' data, we cannot be allowed to # make it again!
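
For reference, a minimal sketch of how the reworked event chain can be exercised once this diff is applied. The context name `xams` is an assumption (use whichever amstrax context registers these plugins), and the run number is a placeholder:

    import amstrax

    st = amstrax.contexts.xams()  # hypothetical context constructor
    run_id = "000001"  # placeholder run number

    # Builds events -> event_basics -> event_positions -> corrected_areas,
    # merged into event_info, and loads the result as a dataframe.
    st.make(run_id, "event_info")
    events = st.get_df(run_id, "event_info")
    print(events[["s1_area", "s2_area", "drift_time", "x", "y", "z", "cs1", "cs2"]].head())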