diff --git a/.buildconfig/ci-linux.yml b/.buildconfig/ci-linux.yml index 650237d55..8b056faf4 100644 --- a/.buildconfig/ci-linux.yml +++ b/.buildconfig/ci-linux.yml @@ -10,36 +10,36 @@ channels: - nodefaults dependencies: - h5py==3.12.1 - - hypothesis==6.119.4 + - hypothesis==6.123.17 - ipykernel==6.29.5 - - ipympl==0.9.4 + - ipympl==0.9.6 - ipywidgets==8.1.5 - mantid==6.11.0 - - matplotlib==3.7.3 + - matplotlib==3.7.3 # Note that this version is out of sync with requirements because of Mantid - mpltoolbox==24.05.1 - plopp==24.10.0 - pooch==1.8.2 - - pytest==8.3.3 + - pytest==8.3.4 - pytest-asyncio==0.24.0 - python-graphviz==0.20.3 - pythreejs==2.4.2 - - scipp==24.11.1 - - scippnexus==24.11.0 - - scipy==1.13 + - scipp==24.11.2 + - scippnexus==24.11.1 + - scipy==1.15.1 - tox==4.23.2 # docs - myst-parser==4.0.0 - - nbsphinx==0.9.5 - - packaging==24.1 + - nbsphinx==0.9.6 + - packaging==24.2 - pandoc==3.4.0 - - pydata-sphinx-theme==0.16.0 + - pydata-sphinx-theme==0.16.1 - sphinx==8.1.3 - - sphinx-autodoc-typehints==2.5.0 + - sphinx-autodoc-typehints==3.0.0 - sphinx-copybutton==0.5.2 - sphinx-design==0.6.1 - sphinxcontrib-bibtex==2.6.3 - - tof==24.12.0 + - tof==25.1.2 # docs and tests - sciline==24.10.0 diff --git a/docs/user-guide/chopper/frame-unwrapping.ipynb b/docs/user-guide/chopper/frame-unwrapping.ipynb index de22c6d10..7f65b3183 100644 --- a/docs/user-guide/chopper/frame-unwrapping.ipynb +++ b/docs/user-guide/chopper/frame-unwrapping.ipynb @@ -29,15 +29,12 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", "import plopp as pp\n", "import scipp as sc\n", "import sciline as sl\n", + "from scippneutron.chopper import DiskChopper\n", "from scippneutron.tof import unwrap\n", - "from scippneutron.tof import chopper_cascade\n", "import tof as tof_pkg\n", - "import matplotlib.pyplot as plt\n", - "from matplotlib.patches import Polygon\n", "\n", "Hz = sc.Unit(\"Hz\")\n", "deg = sc.Unit(\"deg\")\n", @@ -73,12 +70,12 @@ "metadata": {}, "outputs": [], "source": [ - "source = tof_pkg.Source(facility=\"ess\", neutrons=300_000, pulses=5)\n", + "source = tof_pkg.Source(facility=\"ess\", pulses=5)\n", "chopper = tof_pkg.Chopper(\n", " frequency=14.0 * Hz,\n", " open=sc.array(dims=[\"cutout\"], values=[0.0], unit=\"deg\"),\n", " close=sc.array(dims=[\"cutout\"], values=[3.0], unit=\"deg\"),\n", - " phase=85. * deg,\n", + " phase=85.0 * deg,\n", " distance=8.0 * meter,\n", " name=\"chopper\",\n", ")\n", @@ -91,16 +88,22 @@ "\n", "model = tof_pkg.Model(source=source, choppers=[chopper], detectors=detectors)\n", "results = model.run()\n", - "pl = results.plot(cmap='viridis_r')\n", + "pl = results.plot(cmap=\"viridis_r\")\n", "\n", "for i in range(source.pulses):\n", - " pl.ax.axvline(i * (1.0 / source.frequency).to(unit='us').value, color='k', ls='dotted')\n", - " x = [results[det.name].toas.data['visible'][f'pulse:{i}'].coords['toa'].min().value\n", - " for det in detectors]\n", + " pl.ax.axvline(\n", + " i * (1.0 / source.frequency).to(unit=\"us\").value, color=\"k\", ls=\"dotted\"\n", + " )\n", + " x = [\n", + " results[det.name].toas.data[\"visible\"][f\"pulse:{i}\"].coords[\"toa\"].min().value\n", + " for det in detectors\n", + " ]\n", " y = [det.distance.value for det in detectors]\n", - " pl.ax.plot(x, y, '--o', color='magenta', lw=3)\n", + " pl.ax.plot(x, y, \"--o\", color=\"magenta\", lw=3)\n", " if i == 0:\n", - " pl.ax.text(x[2], y[2] * 1.05, \"pivot time\", va='bottom', ha='right', color='magenta')" + " pl.ax.text(\n", + " x[2], y[2] * 1.05, \"pivot time\", va=\"bottom\", ha=\"right\", color=\"magenta\"\n", + " )" ] }, { @@ -131,13 +134,21 @@ "outputs": [], "source": [ "subplots = pp.tiled(2, 2, figsize=(9, 6))\n", + "nxevent_data = results.to_nxevent_data()\n", "for i, det in enumerate(detectors):\n", - " data = results.to_nxevent_data(det.name)\n", - " subplots[i // 2, i % 2] = data.bins.concat().hist(event_time_offset=200).plot(title=f'{det.name}={det.distance:c}', color=f'C{i}')\n", + " data = nxevent_data[\"detector_number\", i]\n", + " subplots[i // 2, i % 2] = (\n", + " data.bins.concat()\n", + " .hist(event_time_offset=200)\n", + " .plot(title=f\"{det.name}={det.distance:c}\", color=f\"C{i}\")\n", + " )\n", " f = subplots[i // 2, i % 2]\n", - " xpiv = min(da.coords['toa'].min() % (1.0 / source.frequency).to(unit='us') for da in results[det.name].toas.data['visible'].values()).value\n", - " f.ax.axvline(xpiv, ls='dashed', color='magenta', lw=2)\n", - " f.ax.text(xpiv, 20, 'pivot time', rotation=90, color='magenta')\n", + " xpiv = min(\n", + " da.coords[\"toa\"].min() % (1.0 / source.frequency).to(unit=\"us\")\n", + " for da in results[det.name].toas.data[\"visible\"].values()\n", + " ).value\n", + " f.ax.axvline(xpiv, ls=\"dashed\", color=\"magenta\", lw=2)\n", + " f.ax.text(xpiv, 20, \"pivot time\", rotation=90, color=\"magenta\")\n", " f.canvas.draw()\n", "subplots" ] @@ -182,29 +193,27 @@ "metadata": {}, "outputs": [], "source": [ - "one_pulse = source.data['pulse', 0]\n", - "time_min = one_pulse.coords['time'].min()\n", - "time_max = one_pulse.coords['time'].max()\n", - "wavs_min = one_pulse.coords['wavelength'].min()\n", - "wavs_max = one_pulse.coords['wavelength'].max()\n", - "oc_times = chopper.open_close_times()\n", - "\n", "workflow = sl.Pipeline(unwrap.providers(), params=unwrap.params())\n", "\n", - "workflow[unwrap.PulsePeriod] = sc.reciprocal(source.frequency)\n", - "workflow[unwrap.SourceTimeRange] = time_min, time_max\n", - "workflow[unwrap.SourceWavelengthRange] = wavs_min, wavs_max\n", + "workflow[unwrap.Facility] = \"ess\"\n", "workflow[unwrap.Choppers] = {\n", - " 'chopper': chopper_cascade.Chopper(\n", - " distance=chopper.distance,\n", - " time_open=oc_times[0].to(unit='s'),\n", - " time_close=oc_times[1].to(unit='s')\n", - " )\n", + " \"chopper\": DiskChopper(\n", + " frequency=-chopper.frequency,\n", + " beam_position=sc.scalar(0.0, unit=\"deg\"),\n", + " phase=-chopper.phase,\n", + " axle_position=sc.vector(\n", + " value=[0, 0, chopper.distance.value], unit=chopper.distance.unit\n", + " ),\n", + " slit_begin=chopper.open,\n", + " slit_end=chopper.close,\n", + " slit_height=sc.scalar(10.0, unit=\"cm\"),\n", + " radius=sc.scalar(30.0, unit=\"cm\"),\n", + " )\n", "}\n", "\n", - "det = detectors[2]\n", - "workflow[unwrap.Ltotal] = det.distance\n", - "workflow[unwrap.RawData] = results.to_nxevent_data(det.name)\n", + "workflow[unwrap.RawData] = nxevent_data\n", + "workflow[unwrap.Ltotal] = nxevent_data.coords[\"Ltotal\"]\n", + "workflow[unwrap.DistanceResolution] = sc.scalar(0.1, unit=\"m\")\n", "\n", "workflow.visualize(unwrap.TofData)" ] @@ -227,7 +236,9 @@ "metadata": {}, "outputs": [], "source": [ - "da = workflow.compute(unwrap.UnwrappedTimeOfArrival)\n", + "da = workflow.compute(unwrap.UnwrappedTimeOfArrival)[\n", + " \"detector_number\", 2\n", + "] # Look at a single detector\n", "da.bins.concat().value.hist(time_of_arrival=300).plot()" ] }, @@ -251,10 +262,12 @@ "metadata": {}, "outputs": [], "source": [ - "da = workflow.compute(unwrap.UnwrappedTimeOfArrivalMinusStartTime)\n", + "da = workflow.compute(unwrap.UnwrappedTimeOfArrivalMinusStartTime)[\"detector_number\", 2]\n", "f = da.bins.concat().value.hist(time_of_arrival=300).plot()\n", "for i in range(source.pulses):\n", - " f.ax.axvline(i * (1.0 / source.frequency).to(unit='us').value, color='k', ls='dotted')\n", + " f.ax.axvline(\n", + " i * (1.0 / source.frequency).to(unit=\"us\").value, color=\"k\", ls=\"dotted\"\n", + " )\n", "f" ] }, @@ -267,7 +280,9 @@ "\n", "#### Unwrapped neutron time-of-arrival modulo the frame period\n", "\n", - "We now wrap the arrival times with the frame period to obtain well formed (unbroken) set of events." + "We now wrap the arrival times with the frame period to obtain well formed (unbroken) set of events.\n", + "\n", + "We also re-add the pivot time offset we had subtracted earlier (to enable to modulo operation)." ] }, { @@ -277,7 +292,7 @@ "metadata": {}, "outputs": [], "source": [ - "da = workflow.compute(unwrap.TimeOfArrivalMinusStartTimeModuloPeriod)\n", + "da = workflow.compute(unwrap.FrameFoldedTimeOfArrival)[\"detector_number\", 2]\n", "da.bins.concat().value.hist(time_of_arrival=200).plot()" ] }, @@ -286,125 +301,68 @@ "id": "16", "metadata": {}, "source": [ - "#### Using the subframes as a lookup table\n", + "#### Create a lookup table\n", "\n", "The chopper information is next used to construct a lookup table that provides an estimate of the real time-of-flight as a function of time-of-arrival.\n", "\n", - "The `chopper_cascade` module can be used to propagate the pulse through the chopper system to the detector,\n", - "and predict the extent of the frames in arrival time and wavelength.\n", + "The `tof` module can be used to propagate a pulse of neutrons through the chopper system to the detectors,\n", + "and predict the most likely neutron wavelength for a given time-of-arrival.\n", "\n", - "Assuming neutrons travel in straight lines, we can convert the wavelength range to a time-of-flight range." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "17", - "metadata": {}, - "outputs": [], - "source": [ - "fig, axs = plt.subplots(1, 2, figsize=(9, 3))\n", + "We typically have hundreds of thousands of pixels in an instrument,\n", + "but it is actually not necessary to propagate the neutrons to 105 detectors.\n", "\n", - "# The chopper cascade frame (at the last component in the system which is the chopper)\n", - "frame = workflow.compute(unwrap.ChopperCascadeFrames)[0][-1]\n", + "Instead, we make a table that spans the entire range of distances of all the pixels,\n", + "with a modest resolution,\n", + "and use a linear interpolation for values that lie between the points in the table.\n", "\n", - "# Plot the ranges covered by the neutrons at each detector as polygons\n", - "polygons = []\n", - "for i, det in enumerate(detectors):\n", - " at_detector = frame.propagate_to(det.distance)\n", - " for sf in at_detector.subframes:\n", - " x = sf.time\n", - " w = sf.wavelength\n", - " t = det.distance * chopper_cascade.wavelength_to_inverse_velocity(w)\n", - " axs[0].add_patch(Polygon(np.array([x.values, w.values]).T, color=f\"C{i}\", alpha=0.8))\n", - " axs[1].add_patch(Polygon(np.array([x.values, t.values]).T, color=f\"C{i}\", alpha=0.8))\n", - " axs[0].text(x.min().value, w.min().value, det.name, va='top', ha='left')\n", - " axs[1].text(x.min().value, t.min().value, det.name, va='top', ha='left')\n", - "\n", - "for ax in axs:\n", - " ax.autoscale()\n", - " ax.set_xlabel(\"Time of arrival [s]\")\n", - "axs[0].set_ylabel(r\"Wavelength [\\AA]\")\n", - "axs[1].set_ylabel(r\"Time of flight [s]\")\n", - "fig.suptitle(\"Lookup tables for wavelength and time-of-flight as a function of time-of-arrival\")\n", - "fig.tight_layout()" - ] - }, - { - "cell_type": "markdown", - "id": "18", - "metadata": {}, - "source": [ - "Since the polygons are very thin, we can **approximate them with straight lines**.\n", - "This is done by using a least-squares method which minimizes the area on each side of the line that passes through the polygon\n", - "(see https://mathproblems123.wordpress.com/2022/09/13/integrating-polynomials-on-polygons for more details).\n", + "To create the table, we thus:\n", "\n", - "These straight lines will be our lookup tables for computing time-of-flight as a function of time-of-arrival\n", - "(we need two lookup tables: one for the slope of each line and one for the intercept)." + "- run a simulation where a pulse of neutrons passes through the choppers and reaches the sample (or any location after the last chopper)\n", + "- propagate the neutrons from the sample to a range of distances that span the minimum and maximum pixel distance from the sample (assuming neutron wavelengths do not change)\n", + "- bin the neutrons in both distance and time-of-arrival (yielding a 2D binned data array)\n", + "- compute the (weighted) mean wavelength inside each bin\n", + "- convert the wavelengths to a real time-of-flight to give our final lookup table\n", + "\n", + "The table can be visualized here:" ] }, { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "17", "metadata": {}, "outputs": [], "source": [ - "fig, ax = plt.subplots(2, 2, figsize=(9, 6))\n", - "\n", - "axs = ax.ravel()\n", - "for i, det in enumerate(detectors):\n", - " workflow[unwrap.Ltotal] = det.distance\n", - " at_detector = workflow.compute(unwrap.FrameAtDetector)\n", - " start = workflow.compute(unwrap.FrameAtDetectorStartTime)\n", - " toa2tof = workflow.compute(unwrap.TimeOfArrivalToTimeOfFlight)\n", - " for sf in at_detector.subframes:\n", - " x = sf.time\n", - " y = det.distance * chopper_cascade.wavelength_to_inverse_velocity(sf.wavelength)\n", - " axs[i].add_patch(Polygon(np.array([x.values, y.values]).T, color=f\"C{i}\", alpha=0.8))\n", - " x = toa2tof.slope.coords['subframe']\n", - " y = toa2tof.slope.squeeze() * x + toa2tof.intercept.squeeze()\n", - " axs[i].plot((x+start).values, y.values, color='k', ls='dashed')\n", - " axs[i].set_xlabel(\"Time of arrival [s]\")\n", - " axs[i].set_ylabel(\"Time of flight [s]\")\n", - " axs[i].set_title(f'{det.name}={det.distance:c}')\n", - "fig.suptitle(\"Approximating the polygons with straight lines\")\n", - "fig.tight_layout()" + "table = workflow.compute(unwrap.TimeOfFlightLookupTable)\n", + "table.plot()" ] }, { "cell_type": "markdown", - "id": "20", + "id": "18", "metadata": {}, "source": [ "#### Computing time-of-flight from the lookup\n", "\n", - "Now that we have a slope and an intercept for the frames at each detector,\n", - "we can compute the time-of-flight of the neutrons." + "We now use the above table to perform a bilinear interpolation and compute the time-of-flight of every neutron." ] }, { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "19", "metadata": {}, "outputs": [], "source": [ - "tofs = {}\n", - "\n", - "for det in detectors:\n", - " workflow[unwrap.RawData] = results.to_nxevent_data(det.name)\n", - " workflow[unwrap.Ltotal] = det.distance\n", - " t = workflow.compute(unwrap.TofData)\n", - " tofs[det.name] = t.bins.concat().value.hist(tof=sc.scalar(500., unit='us'))\n", - " tofs[det.name].coords['Ltotal'] = t.coords['distance']\n", + "tofs = workflow.compute(unwrap.TofData)\n", "\n", - "pp.plot(tofs)" + "tof_hist = tofs.bins.concat(\"pulse\").hist(tof=sc.scalar(500.0, unit=\"us\"))\n", + "pp.plot({det.name: tof_hist[\"detector_number\", i] for i, det in enumerate(detectors)})" ] }, { "cell_type": "markdown", - "id": "22", + "id": "20", "metadata": {}, "source": [ "### Converting to wavelength\n", @@ -418,7 +376,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -431,24 +389,26 @@ "# Define wavelength bin edges\n", "bins = sc.linspace(\"wavelength\", 6.0, 9.0, 101, unit=\"angstrom\")\n", "\n", - "wavs = {}\n", - "for det in detectors:\n", - " workflow[unwrap.RawData] = results.to_nxevent_data(det.name)\n", - " workflow[unwrap.Ltotal] = det.distance\n", - " t = workflow.compute(unwrap.TofData)\n", - " t.coords['Ltotal'] = t.coords.pop('distance')\n", - " wavs[det.name] = t.transform_coords(\"wavelength\", graph=graph).bins.concat().hist(wavelength=bins)\n", + "# Compute wavelengths\n", + "wav_hist = (\n", + " tofs.transform_coords(\"wavelength\", graph=graph)\n", + " .bins.concat(\"pulse\")\n", + " .hist(wavelength=bins)\n", + ")\n", + "wavs = {det.name: wav_hist[\"detector_number\", i] for i, det in enumerate(detectors)}\n", "\n", "ground_truth = results[\"detector\"].data.flatten(to=\"event\")\n", - "ground_truth = ground_truth[~ground_truth.masks[\"blocked_by_others\"]].hist(wavelength=bins)\n", + "ground_truth = ground_truth[~ground_truth.masks[\"blocked_by_others\"]].hist(\n", + " wavelength=bins\n", + ")\n", "\n", - "wavs['true'] = ground_truth\n", + "wavs[\"true\"] = ground_truth\n", "pp.plot(wavs)" ] }, { "cell_type": "markdown", - "id": "24", + "id": "22", "metadata": {}, "source": [ "We see that all detectors agree on the wavelength spectrum,\n", @@ -469,17 +429,17 @@ { "cell_type": "code", "execution_count": null, - "id": "25", + "id": "23", "metadata": {}, "outputs": [], "source": [ - "source = tof_pkg.Source(facility=\"ess\", neutrons=300_000, pulses=4)\n", + "source = tof_pkg.Source(facility=\"ess\", pulses=4)\n", "choppers = [\n", " tof_pkg.Chopper(\n", " frequency=14.0 * Hz,\n", " open=sc.array(dims=[\"cutout\"], values=[0.0], unit=\"deg\"),\n", " close=sc.array(dims=[\"cutout\"], values=[33.0], unit=\"deg\"),\n", - " phase=35. * deg,\n", + " phase=35.0 * deg,\n", " distance=8.0 * meter,\n", " name=\"chopper\",\n", " ),\n", @@ -487,10 +447,10 @@ " frequency=7.0 * Hz,\n", " open=sc.array(dims=[\"cutout\"], values=[0.0], unit=\"deg\"),\n", " close=sc.array(dims=[\"cutout\"], values=[120.0], unit=\"deg\"),\n", - " phase=10. * deg,\n", + " phase=10.0 * deg,\n", " distance=15.0 * meter,\n", " name=\"pulse-skipping\",\n", - " )\n", + " ),\n", "]\n", "detectors = [\n", " tof_pkg.Detector(distance=60.0 * meter, name=\"monitor\"),\n", @@ -499,12 +459,12 @@ "\n", "model = tof_pkg.Model(source=source, choppers=choppers, detectors=detectors)\n", "results = model.run()\n", - "results.plot(cmap='viridis_r', blocked_rays=5000)" + "results.plot(cmap=\"viridis_r\", blocked_rays=5000)" ] }, { "cell_type": "markdown", - "id": "26", + "id": "24", "metadata": {}, "source": [ "### Computing time-of-flight\n", @@ -518,78 +478,59 @@ { "cell_type": "code", "execution_count": null, - "id": "27", + "id": "25", "metadata": {}, "outputs": [], "source": [ - "one_pulse = source.data['pulse', 0]\n", - "time_min = one_pulse.coords['time'].min()\n", - "time_max = one_pulse.coords['time'].max()\n", - "wavs_min = one_pulse.coords['wavelength'].min()\n", - "wavs_max = one_pulse.coords['wavelength'].max()\n", - "\n", "workflow = sl.Pipeline(unwrap.providers(), params=unwrap.params())\n", - "workflow[unwrap.PulsePeriod] = sc.reciprocal(source.frequency)\n", + "\n", + "workflow[unwrap.Facility] = \"ess\"\n", "workflow[unwrap.PulseStride] = 2\n", - "workflow[unwrap.SourceTimeRange] = time_min, time_max\n", - "workflow[unwrap.SourceWavelengthRange] = wavs_min, wavs_max\n", "workflow[unwrap.Choppers] = {\n", - " ch.name: chopper_cascade.Chopper(\n", - " distance=ch.distance,\n", - " time_open=ch.open_close_times()[0].to(unit='s'),\n", - " time_close=ch.open_close_times()[1].to(unit='s')\n", - " )\n", + " ch.name: DiskChopper(\n", + " frequency=-ch.frequency,\n", + " beam_position=sc.scalar(0.0, unit=\"deg\"),\n", + " phase=-ch.phase,\n", + " axle_position=sc.vector(\n", + " value=[0, 0, ch.distance.value], unit=chopper.distance.unit\n", + " ),\n", + " slit_begin=ch.open,\n", + " slit_end=ch.close,\n", + " slit_height=sc.scalar(10.0, unit=\"cm\"),\n", + " radius=sc.scalar(30.0, unit=\"cm\"),\n", + " )\n", " for ch in choppers\n", "}\n", "\n", - "det = detectors[-1]\n", - "workflow[unwrap.Ltotal] = det.distance\n", - "workflow[unwrap.RawData] = results.to_nxevent_data(det.name)" + "nxevent_data = results.to_nxevent_data()\n", + "workflow[unwrap.RawData] = nxevent_data\n", + "workflow[unwrap.Ltotal] = nxevent_data.coords[\"Ltotal\"]\n", + "workflow[unwrap.DistanceResolution] = sc.scalar(0.5, unit=\"m\")" ] }, { "cell_type": "markdown", - "id": "28", + "id": "26", "metadata": {}, "source": [ - "If we inspect the time and wavelength polygons for the frames at the different detectors,\n", - "we can see that they now span longer than the pulse period of 71 ms." + "If we inspect the time-of-flight lookup table,\n", + "we can see that the time-of-arrival (toa) dimension now spans longer than the pulse period of 71 ms." ] }, { "cell_type": "code", "execution_count": null, - "id": "29", + "id": "27", "metadata": {}, "outputs": [], "source": [ - "frame = workflow.compute(unwrap.ChopperCascadeFrames)[0][-1]\n", - "\n", - "\n", - "fig, axs = plt.subplots(1, 2, figsize=(9, 3))\n", - "\n", - "polygons = []\n", - "for i, det in enumerate(detectors):\n", - " at_detector = frame.propagate_to(det.distance)\n", - " for sf in at_detector.subframes:\n", - " x = sf.time\n", - " w = sf.wavelength\n", - " t = det.distance * chopper_cascade.wavelength_to_inverse_velocity(w)\n", - " axs[0].add_patch(Polygon(np.array([x.values, w.values]).T, color=f\"C{i}\", alpha=0.8))\n", - " axs[1].add_patch(Polygon(np.array([x.values, t.values]).T, color=f\"C{i}\", alpha=0.8))\n", - "\n", - "for ax in axs:\n", - " ax.autoscale()\n", - " ax.set_xlabel(\"Time of arrival [s]\")\n", - "axs[0].set_ylabel(r\"Wavelength [\\AA]\")\n", - "axs[1].set_ylabel(r\"Time of flight [s]\")\n", - "fig.suptitle(\"Lookup tables for wavelength and time-of-flight as a function of time-of-arrival\")\n", - "fig.tight_layout()" + "table = workflow.compute(unwrap.TimeOfFlightLookupTable)\n", + "table.plot()" ] }, { "cell_type": "markdown", - "id": "30", + "id": "28", "metadata": {}, "source": [ "The time-of-flight profiles are then:" @@ -598,25 +539,19 @@ { "cell_type": "code", "execution_count": null, - "id": "31", + "id": "29", "metadata": {}, "outputs": [], "source": [ - "tofs = {}\n", - "\n", - "for det in detectors:\n", - " workflow[unwrap.RawData] = results.to_nxevent_data(det.name)\n", - " workflow[unwrap.Ltotal] = det.distance\n", - " t = workflow.compute(unwrap.TofData)\n", - " tofs[det.name] = t.bins.concat().value.hist(tof=sc.scalar(500., unit='us'))\n", - " tofs[det.name].coords['Ltotal'] = t.coords['distance']\n", + "tofs = workflow.compute(unwrap.TofData)\n", "\n", - "pp.plot(tofs)" + "tof_hist = tofs.bins.concat(\"pulse\").hist(tof=sc.scalar(500.0, unit=\"us\"))\n", + "pp.plot({det.name: tof_hist[\"detector_number\", i] for i, det in enumerate(detectors)})" ] }, { "cell_type": "markdown", - "id": "32", + "id": "30", "metadata": {}, "source": [ "### Conversion to wavelength\n", @@ -627,33 +562,33 @@ { "cell_type": "code", "execution_count": null, - "id": "33", + "id": "31", "metadata": {}, "outputs": [], "source": [ "# Define wavelength bin edges\n", "bins = sc.linspace(\"wavelength\", 1.0, 8.0, 401, unit=\"angstrom\")\n", "\n", - "wavs = {}\n", - "for det in detectors:\n", - " workflow[unwrap.RawData] = results.to_nxevent_data(det.name)\n", - " workflow[unwrap.Ltotal] = det.distance\n", - " t = workflow.compute(unwrap.TofData)\n", - " t.coords['Ltotal'] = t.coords.pop('distance')\n", - " wavs[det.name] = t.transform_coords(\"wavelength\", graph=graph).bins.concat().hist(wavelength=bins)\n", + "# Compute wavelengths\n", + "wav_hist = (\n", + " tofs.transform_coords(\"wavelength\", graph=graph)\n", + " .bins.concat(\"pulse\")\n", + " .hist(wavelength=bins)\n", + ")\n", + "wavs = {det.name: wav_hist[\"detector_number\", i] for i, det in enumerate(detectors)}\n", "\n", "ground_truth = results[\"detector\"].data.flatten(to=\"event\")\n", - "ground_truth = ground_truth[~ground_truth.masks[\"blocked_by_others\"]].hist(wavelength=bins)\n", - "\n", - "wavs['true'] = ground_truth\n", - "\n", + "ground_truth = ground_truth[~ground_truth.masks[\"blocked_by_others\"]].hist(\n", + " wavelength=bins\n", + ")\n", "\n", + "wavs[\"true\"] = ground_truth\n", "pp.plot(wavs)" ] }, { "cell_type": "markdown", - "id": "34", + "id": "32", "metadata": {}, "source": [ "## Wavelength-frame multiplication mode\n", @@ -669,30 +604,30 @@ { "cell_type": "code", "execution_count": null, - "id": "35", + "id": "33", "metadata": {}, "outputs": [], "source": [ - "source = tof_pkg.Source(facility=\"ess\", neutrons=500_000, pulses=2)\n", + "source = tof_pkg.Source(facility=\"ess\", pulses=2)\n", "\n", - "slit_width = 3.\n", - "open_edge = sc.linspace('cutout', 0., 75, 6, unit='deg')\n", + "slit_width = 3.0\n", + "open_edge = sc.linspace(\"cutout\", 0.0, 75, 6, unit=\"deg\")\n", "wfm = tof_pkg.Chopper(\n", " frequency=14.0 * Hz,\n", " open=open_edge,\n", " close=open_edge + slit_width * deg,\n", - " phase=45. * deg,\n", + " phase=45.0 * deg,\n", " distance=8.0 * meter,\n", " name=\"WFM\",\n", ")\n", "\n", - "slit_width = 25.\n", - "open_edge = sc.linspace('cutout', 0., 190, 6, unit='deg')\n", + "slit_width = 25.0\n", + "open_edge = sc.linspace(\"cutout\", 0.0, 190, 6, unit=\"deg\")\n", "foc = tof_pkg.Chopper(\n", " frequency=14.0 * Hz,\n", " open=open_edge,\n", " close=open_edge + slit_width * deg,\n", - " phase=85. * deg,\n", + " phase=85.0 * deg,\n", " distance=20.0 * meter,\n", " name=\"FOC\",\n", ")\n", @@ -704,12 +639,12 @@ "\n", "model = tof_pkg.Model(source=source, choppers=choppers, detectors=detectors)\n", "results = model.run()\n", - "results.plot(cmap='viridis_r', blocked_rays=5000)" + "results.plot(cmap=\"viridis_r\", blocked_rays=5000)" ] }, { "cell_type": "markdown", - "id": "36", + "id": "34", "metadata": {}, "source": [ "The signal of the raw `event_time_offset` at the detector wras around the 71 ms mark:" @@ -718,16 +653,16 @@ { "cell_type": "code", "execution_count": null, - "id": "37", + "id": "35", "metadata": {}, "outputs": [], "source": [ - "results.to_nxevent_data(det.name).bins.concat().hist(event_time_offset=300).plot()" + "results.to_nxevent_data().bins.concat().hist(event_time_offset=300).plot()" ] }, { "cell_type": "markdown", - "id": "38", + "id": "36", "metadata": {}, "source": [ "### Setting up the workflow\n", @@ -741,63 +676,60 @@ { "cell_type": "code", "execution_count": null, - "id": "39", + "id": "37", "metadata": {}, "outputs": [], "source": [ - "one_pulse = source.data['pulse', 0]\n", - "time_min = one_pulse.coords['time'].min()\n", - "time_max = one_pulse.coords['time'].max()\n", - "wavs_min = one_pulse.coords['wavelength'].min()\n", - "wavs_max = one_pulse.coords['wavelength'].max()\n", - "\n", "workflow = sl.Pipeline(unwrap.providers(), params=unwrap.params())\n", "\n", - "workflow[unwrap.PulsePeriod] = sc.reciprocal(source.frequency)\n", - "workflow[unwrap.SourceTimeRange] = time_min, time_max\n", - "workflow[unwrap.SourceWavelengthRange] = wavs_min, wavs_max\n", + "workflow[unwrap.Facility] = \"ess\"\n", "workflow[unwrap.Choppers] = {\n", - " ch.name: chopper_cascade.Chopper(\n", - " distance=ch.distance,\n", - " time_open=ch.open_close_times()[0].to(unit='s'),\n", - " time_close=ch.open_close_times()[1].to(unit='s')\n", - " )\n", + " ch.name: DiskChopper(\n", + " frequency=-ch.frequency,\n", + " beam_position=sc.scalar(0.0, unit=\"deg\"),\n", + " phase=-ch.phase,\n", + " axle_position=sc.vector(\n", + " value=[0, 0, ch.distance.value], unit=chopper.distance.unit\n", + " ),\n", + " slit_begin=ch.open,\n", + " slit_end=ch.close,\n", + " slit_height=sc.scalar(10.0, unit=\"cm\"),\n", + " radius=sc.scalar(30.0, unit=\"cm\"),\n", + " )\n", " for ch in choppers\n", "}\n", "\n", - "det = detectors[-1]\n", - "workflow[unwrap.Ltotal] = det.distance\n", - "workflow[unwrap.RawData] = results.to_nxevent_data(det.name)" + "nxevent_data = results.to_nxevent_data()\n", + "workflow[unwrap.RawData] = nxevent_data\n", + "workflow[unwrap.Ltotal] = nxevent_data.coords[\"Ltotal\"]\n", + "workflow[unwrap.DistanceResolution] = sc.scalar(0.5, unit=\"m\")" ] }, { "cell_type": "markdown", - "id": "40", + "id": "38", "metadata": {}, "source": [ - "At this point it is useful to look at the propagation of the pulse through the chopper cascade:" + "This time, the lookup table has 6 distinct bands separated by empty space,\n", + "corresponding to the 6 WFM frames from the 6 chopper openings." ] }, { "cell_type": "code", "execution_count": null, - "id": "41", + "id": "39", "metadata": {}, "outputs": [], "source": [ - "frames = workflow.compute(unwrap.ChopperCascadeFrames)[0]\n", - "at_detector = frames.propagate_to(det.distance)\n", - "at_detector.draw()" + "table = workflow.compute(unwrap.TimeOfFlightLookupTable)\n", + "table.plot()" ] }, { "cell_type": "markdown", - "id": "42", + "id": "40", "metadata": {}, "source": [ - "It illustrates nicely the pulse being chopped into 6 pieces,\n", - "corresponding to the 6 openings of the choppers.\n", - "\n", "### Computing time-of-flight\n", "\n", "The time-of-flight profile resembles the `event_time_offset` profile above,\n", @@ -807,7 +739,7 @@ { "cell_type": "code", "execution_count": null, - "id": "43", + "id": "41", "metadata": {}, "outputs": [], "source": [ @@ -817,7 +749,7 @@ }, { "cell_type": "markdown", - "id": "44", + "id": "42", "metadata": {}, "source": [ "### Conversion to wavelength\n", @@ -828,20 +760,23 @@ { "cell_type": "code", "execution_count": null, - "id": "45", + "id": "43", "metadata": {}, "outputs": [], "source": [ "# Define wavelength bin edges\n", "bins = sc.linspace(\"wavelength\", 2.0, 12.0, 401, unit=\"angstrom\")\n", "\n", - "tofs.coords['Ltotal'] = tofs.coords.pop('distance')\n", - "wavs = tofs.transform_coords(\"wavelength\", graph=graph).bins.concat().hist(wavelength=bins)\n", + "wavs = (\n", + " tofs.transform_coords(\"wavelength\", graph=graph).bins.concat().hist(wavelength=bins)\n", + ")\n", "\n", "ground_truth = results[\"detector\"].data.flatten(to=\"event\")\n", - "ground_truth = ground_truth[~ground_truth.masks[\"blocked_by_others\"]].hist(wavelength=bins)\n", + "ground_truth = ground_truth[~ground_truth.masks[\"blocked_by_others\"]].hist(\n", + " wavelength=bins\n", + ")\n", "\n", - "pp.plot({'wfm': wavs, 'true': ground_truth})" + "pp.plot({\"wfm\": wavs, \"true\": ground_truth})" ] } ], diff --git a/docs/user-guide/wfm/dream-wfm.ipynb b/docs/user-guide/wfm/dream-wfm.ipynb index e9e03df95..23a41983d 100644 --- a/docs/user-guide/wfm/dream-wfm.ipynb +++ b/docs/user-guide/wfm/dream-wfm.ipynb @@ -22,8 +22,7 @@ "import scipp as sc\n", "import sciline as sl\n", "from scippneutron.chopper import DiskChopper\n", - "from scippneutron.tof import unwrap\n", - "from scippneutron.tof import chopper_cascade" + "from scippneutron.tof import unwrap" ] }, { @@ -167,33 +166,6 @@ "cell_type": "markdown", "id": "8", "metadata": {}, - "source": [ - "### Convert the choppers\n", - "\n", - "Lastly, we convert our disk choppers to a simpler chopper representation used by the `chopper_cascade` module. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9", - "metadata": {}, - "outputs": [], - "source": [ - "choppers = {\n", - " key: chopper_cascade.Chopper.from_disk_chopper(\n", - " chop,\n", - " pulse_frequency=sc.scalar(14.0, unit=\"Hz\"),\n", - " npulses=1,\n", - " )\n", - " for key, chop in disk_choppers.items()\n", - "}" - ] - }, - { - "cell_type": "markdown", - "id": "10", - "metadata": {}, "source": [ "## Creating some neutron events\n", "\n", @@ -203,14 +175,14 @@ { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "9", "metadata": {}, "outputs": [], "source": [ "from scippneutron.tof.fakes import FakeBeamlineEss\n", "\n", "ess_beamline = FakeBeamlineEss(\n", - " choppers=choppers,\n", + " choppers=disk_choppers,\n", " monitors={\"detector\": Ltotal},\n", " run_length=sc.scalar(1 / 14, unit=\"s\") * 4,\n", " events_per_pulse=200_000,\n", @@ -219,7 +191,7 @@ }, { "cell_type": "markdown", - "id": "12", + "id": "10", "metadata": {}, "source": [ "The initial birth times and wavelengths of the generated neutrons can be visualized (for a single pulse):" @@ -228,7 +200,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "11", "metadata": {}, "outputs": [], "source": [ @@ -239,7 +211,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "12", "metadata": {}, "outputs": [], "source": [ @@ -248,7 +220,7 @@ }, { "cell_type": "markdown", - "id": "15", + "id": "13", "metadata": {}, "source": [ "From this fake beamline, we extract the raw neutron signal at our detector:" @@ -257,7 +229,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -269,7 +241,7 @@ }, { "cell_type": "markdown", - "id": "17", + "id": "15", "metadata": {}, "source": [ "The total number of neutrons in our sample data that make it through the to detector is:" @@ -278,7 +250,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "16", "metadata": {}, "outputs": [], "source": [ @@ -287,165 +259,107 @@ }, { "cell_type": "markdown", - "id": "19", - "metadata": {}, - "source": [ - "## Using the chopper cascade to chop the pulse\n", - "\n", - "The `chopper_cascade` module can now be used to chop a pulse of neutrons using the choppers created above." - ] - }, - { - "cell_type": "markdown", - "id": "20", + "id": "17", "metadata": {}, "source": [ - "### Create a pulse of neutrons\n", + "## Computing time-of-flight\n", "\n", - "We then create a (fake) pulse of neutrons, whose time and wavelength ranges are close to that of our ESS pulse above:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "21", - "metadata": {}, - "outputs": [], - "source": [ - "pulse_tmin = one_pulse.coords[\"time\"].min()\n", - "pulse_tmax = one_pulse.coords[\"time\"].max()\n", - "pulse_wmin = one_pulse.coords[\"wavelength\"].min()\n", - "pulse_wmax = one_pulse.coords[\"wavelength\"].max()\n", + "The chopper information is next used to construct a lookup table that provides an estimate of the real time-of-flight as a function of neutron time-of-arrival.\n", "\n", - "frames = chopper_cascade.FrameSequence.from_source_pulse(\n", - " time_min=pulse_tmin,\n", - " time_max=pulse_tmax,\n", - " wavelength_min=pulse_wmin,\n", - " wavelength_max=pulse_wmax,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "22", - "metadata": {}, - "source": [ - "### Propagate the neutrons through the choppers\n", + "We use the `tof` module to propagate a pulse of neutrons through the chopper system to the detectors,\n", + "and predict the most likely neutron wavelength for a given time-of-arrival and distance from source.\n", "\n", - "We are now able to propagate the pulse of neutrons through the chopper cascade,\n", - "chopping away the parts of the pulse that do not make it through.\n", + "From this,\n", + "we build a lookup table on which bilinear interpolation is used to compute a wavelength (and its corresponding time-of-flight)\n", + "for every neutron event.\n", "\n", - "For this, we need to decide how far we want to propagate the neutrons, by choosing a distance to our detector.\n", - "We set this to 32 meters here." + "### Setting up the workflow" ] }, { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "18", "metadata": {}, "outputs": [], "source": [ - "# Chop the frames\n", - "chopped = frames.chop(choppers.values())\n", + "workflow = sl.Pipeline(unwrap.providers(), params=unwrap.params())\n", "\n", - "# Propagate the neutrons to the detector\n", - "at_sample = chopped.propagate_to(Ltotal)\n", + "workflow[unwrap.Facility] = \"ess\"\n", + "workflow[unwrap.Choppers] = disk_choppers\n", + "workflow[unwrap.RawData] = raw_data\n", + "workflow[unwrap.Ltotal] = Ltotal\n", "\n", - "# Visualize the results\n", - "cascade_fig, cascade_ax = at_sample.draw()" - ] - }, - { - "cell_type": "markdown", - "id": "24", - "metadata": {}, - "source": [ - "We can now see that at the detector (pink color), we have 2 sub-pulses of neutrons,\n", - "where the longest wavelength of one frame is very close to the shortest wavelength of the next frame." + "workflow.visualize(unwrap.TofData)" ] }, { "cell_type": "markdown", - "id": "25", + "id": "19", "metadata": {}, "source": [ - "## Computing the time-of-flight coordinate using Sciline\n", + "### Inspecting the lookup table\n", + "\n", + "The workflow first runs a `tof` simulation using the chopper parameters above,\n", + "and the result is stored in `SimulationResults` (see graph above).\n", "\n", - "### Setting up the workflow\n", + "From these simulated neutrons, we create figures displaying the neutron wavelengths and time-of-flight,\n", + "as a function of arrival time at the detector.\n", "\n", - "We will now construct a workflow to compute the `time-of-flight` coordinate from the neutron events above,\n", - "taking into account the choppers in the beamline." + "This is the basis for creating our lookup table." ] }, { "cell_type": "code", "execution_count": null, - "id": "26", + "id": "20", "metadata": {}, "outputs": [], "source": [ - "workflow = sl.Pipeline(unwrap.providers(), params=unwrap.params())\n", - "\n", - "workflow[unwrap.PulsePeriod] = sc.reciprocal(ess_beamline.source.frequency)\n", - "workflow[unwrap.SourceTimeRange] = pulse_tmin, pulse_tmax\n", - "workflow[unwrap.SourceWavelengthRange] = pulse_wmin, pulse_wmax\n", - "workflow[unwrap.Choppers] = choppers\n", - "\n", - "workflow[unwrap.Ltotal] = Ltotal\n", - "workflow[unwrap.RawData] = raw_data\n", + "sim = workflow.compute(unwrap.SimulationResults)\n", + "# Compute time-of-arrival at the detector\n", + "tarrival = sim.time_of_arrival + ((Ltotal - sim.distance) / sim.speed).to(unit=\"us\")\n", + "# Compute time-of-flight at the detector\n", + "tflight = (Ltotal / sim.speed).to(unit=\"us\")\n", "\n", - "workflow.visualize(unwrap.TofData)" + "events = sc.DataArray(\n", + " data=sim.weight,\n", + " coords={\"wavelength\": sim.wavelength, \"toa\": tarrival, \"tof\": tflight},\n", + ")\n", + "fig1 = events.hist(wavelength=300, toa=300).plot(norm=\"log\")\n", + "fig2 = events.hist(tof=300, toa=300).plot(norm=\"log\")\n", + "fig1 + fig2" ] }, { "cell_type": "markdown", - "id": "27", + "id": "21", "metadata": {}, "source": [ - "### Checking the frame bounds\n", + "The lookup table is then obtained by computing the weighted mean of the time-of-flight inside each time-of-arrival bin.\n", "\n", - "We can check that the bounds for the frames the workflow computes agrees with the chopper-cascade diagram" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "28", - "metadata": {}, - "outputs": [], - "source": [ - "bounds = workflow.compute(unwrap.FrameAtDetector).subbounds()\n", - "bounds" + "This is illustrated by the orange line in the figure below:" ] }, { "cell_type": "code", "execution_count": null, - "id": "29", + "id": "22", "metadata": {}, "outputs": [], "source": [ - "for b in sc.collapse(bounds[\"time\"], keep=\"bound\").values():\n", - " cascade_ax.axvspan(\n", - " b[0].to(unit=\"ms\").value,\n", - " b[1].to(unit=\"ms\").value,\n", - " color=\"gray\",\n", - " alpha=0.2,\n", - " zorder=-5,\n", - " )\n", + "table = workflow.compute(unwrap.TimeOfFlightLookupTable)\n", "\n", - "cascade_fig" + "# Overlay mean on the figure above\n", + "table[\"distance\", 1].plot(ax=fig2.ax, color=\"C1\")" ] }, { "attachments": {}, "cell_type": "markdown", - "id": "30", + "id": "23", "metadata": {}, "source": [ - "There should be one vertical band matching the extent of each pink polygon, which there is.\n", - "\n", "### Computing a time-of-flight coordinate\n", "\n", "We will now use our workflow to obtain our event data with a time-of-flight coordinate:" @@ -454,7 +368,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31", + "id": "24", "metadata": {}, "outputs": [], "source": [ @@ -464,7 +378,7 @@ }, { "cell_type": "markdown", - "id": "32", + "id": "25", "metadata": {}, "source": [ "Histogramming the data for a plot should show a profile with 6 bumps that correspond to the frames:" @@ -473,7 +387,7 @@ { "cell_type": "code", "execution_count": null, - "id": "33", + "id": "26", "metadata": {}, "outputs": [], "source": [ @@ -482,7 +396,7 @@ }, { "cell_type": "markdown", - "id": "34", + "id": "27", "metadata": {}, "source": [ "### Converting to wavelength\n", @@ -493,7 +407,7 @@ { "cell_type": "code", "execution_count": null, - "id": "35", + "id": "28", "metadata": {}, "outputs": [], "source": [ @@ -512,7 +426,7 @@ }, { "cell_type": "markdown", - "id": "36", + "id": "29", "metadata": {}, "source": [ "### Comparing to the ground truth\n", @@ -524,7 +438,7 @@ { "cell_type": "code", "execution_count": null, - "id": "37", + "id": "30", "metadata": {}, "outputs": [], "source": [ @@ -541,7 +455,7 @@ }, { "cell_type": "markdown", - "id": "38", + "id": "31", "metadata": {}, "source": [ "## Multiple detector pixels\n", @@ -557,17 +471,15 @@ { "cell_type": "code", "execution_count": null, - "id": "39", + "id": "32", "metadata": {}, "outputs": [], "source": [ - "Ltotal = sc.array(dims=['detector_number'], values=[77.675, 76.0], unit='m')\n", - "monitors = {\n", - " f\"detector{i}\": ltot for i, ltot in enumerate(Ltotal)\n", - " }\n", + "Ltotal = sc.array(dims=[\"detector_number\"], values=[77.675, 76.0], unit=\"m\")\n", + "monitors = {f\"detector{i}\": ltot for i, ltot in enumerate(Ltotal)}\n", "\n", "ess_beamline = FakeBeamlineEss(\n", - " choppers=choppers,\n", + " choppers=disk_choppers,\n", " monitors=monitors,\n", " run_length=sc.scalar(1 / 14, unit=\"s\") * 4,\n", " events_per_pulse=200_000,\n", @@ -576,7 +488,7 @@ }, { "cell_type": "markdown", - "id": "40", + "id": "33", "metadata": {}, "source": [ "Our raw data has now a `detector_number` dimension of length 2.\n", @@ -587,22 +499,26 @@ { "cell_type": "code", "execution_count": null, - "id": "41", + "id": "34", "metadata": {}, "outputs": [], "source": [ "raw_data = sc.concat(\n", - " [ess_beamline.get_monitor(key)[0] for key in monitors.keys()],\n", - " dim='detector_number',\n", - " )\n", + " [ess_beamline.get_monitor(key)[0] for key in monitors.keys()],\n", + " dim=\"detector_number\",\n", + ")\n", "\n", "# Visualize\n", - "pp.plot(sc.collapse(raw_data.hist(event_time_offset=300).sum(\"pulse\"), keep='event_time_offset'))" + "pp.plot(\n", + " sc.collapse(\n", + " raw_data.hist(event_time_offset=300).sum(\"pulse\"), keep=\"event_time_offset\"\n", + " )\n", + ")" ] }, { "cell_type": "markdown", - "id": "42", + "id": "35", "metadata": {}, "source": [ "Computing time-of-flight is done in the same way as above.\n", @@ -612,7 +528,7 @@ { "cell_type": "code", "execution_count": null, - "id": "43", + "id": "36", "metadata": {}, "outputs": [], "source": [ @@ -633,10 +549,12 @@ "figs = [\n", " pp.plot(\n", " {\n", - " \"wfm\": wav_wfm['detector_number', i].bins.concat().hist(wavelength=wavs),\n", + " \"wfm\": wav_wfm[\"detector_number\", i].bins.concat().hist(wavelength=wavs),\n", " \"ground_truth\": ground_truth[i].hist(wavelength=wavs),\n", - " }, title=f\"Pixel {i+1}\"\n", - " ) for i in range(len(Ltotal))\n", + " },\n", + " title=f\"Pixel {i+1}\",\n", + " )\n", + " for i in range(len(Ltotal))\n", "]\n", "\n", "figs[0] + figs[1]" @@ -644,7 +562,7 @@ }, { "cell_type": "markdown", - "id": "44", + "id": "37", "metadata": {}, "source": [ "## Handling time overlap between subframes\n", @@ -656,7 +574,7 @@ "but arrive at the same time at the detector.\n", "\n", "In this case, it is actually not possible to accurately determine the wavelength of the neutrons.\n", - "ScippNeutron handles this by clipping the overlapping regions and throwing away any neutrons that lie within it.\n", + "ScippNeutron handles this by masking the overlapping regions and throwing away any neutrons that lie within it.\n", "\n", "To simulate this, we modify slightly the phase and the cutouts of the band-control chopper:" ] @@ -664,11 +582,11 @@ { "cell_type": "code", "execution_count": null, - "id": "45", + "id": "38", "metadata": {}, "outputs": [], "source": [ - "disk_choppers['bcc'] = DiskChopper(\n", + "disk_choppers[\"bcc\"] = DiskChopper(\n", " frequency=sc.scalar(112.0, unit=\"Hz\"),\n", " beam_position=sc.scalar(0.0, unit=\"deg\"),\n", " phase=sc.scalar(240 - 180, unit=\"deg\"),\n", @@ -679,67 +597,114 @@ " radius=sc.scalar(30.0, unit=\"cm\"),\n", ")\n", "\n", - "\n", - "# Update the choppers\n", - "choppers = {\n", - " key: chopper_cascade.Chopper.from_disk_chopper(\n", - " chop,\n", - " pulse_frequency=sc.scalar(14.0, unit=\"Hz\"),\n", - " npulses=1,\n", - " )\n", - " for key, chop in disk_choppers.items()\n", - "}\n", - "\n", "# Go back to a single detector pixel\n", "Ltotal = sc.scalar(76.55 + 1.125, unit=\"m\")\n", "\n", "ess_beamline = FakeBeamlineEss(\n", - " choppers=choppers,\n", + " choppers=disk_choppers,\n", " monitors={\"detector\": Ltotal},\n", " run_length=sc.scalar(1 / 14, unit=\"s\") * 4,\n", " events_per_pulse=200_000,\n", ")\n", "\n", + "# Update workflow\n", + "workflow[unwrap.Ltotal] = Ltotal\n", + "workflow[unwrap.RawData] = ess_beamline.get_monitor(\"detector\")[0]\n", + "\n", "ess_beamline.model_result.plot()" ] }, { "cell_type": "markdown", - "id": "46", + "id": "39", "metadata": {}, "source": [ "We can now see that there is no longer a gap between the two frames at the center of each pulse (green region).\n", "\n", - "Another way of looking at this is making the chopper-cascade diagram and zooming in on the last frame (pink),\n", - "which also show overlap in time:" + "Another way of looking at this is looking at the wavelength vs time-of-arrival plot,\n", + "which also shows overlap in time at the junction between the two frames:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40", + "metadata": {}, + "outputs": [], + "source": [ + "sim = workflow.compute(unwrap.SimulationResults)\n", + "# Compute time-of-arrival at the detector\n", + "tarrival = sim.time_of_arrival + ((Ltotal - sim.distance) / sim.speed).to(unit=\"us\")\n", + "# Compute time-of-flight at the detector\n", + "tflight = (Ltotal / sim.speed).to(unit=\"us\")\n", + "\n", + "events = sc.DataArray(\n", + " data=sim.weight,\n", + " coords={\"wavelength\": sim.wavelength, \"toa\": tarrival, \"tof\": tflight},\n", + ")\n", + "events.hist(wavelength=300, toa=300).plot(norm=\"log\")" + ] + }, + { + "cell_type": "markdown", + "id": "41", + "metadata": {}, + "source": [ + "The data in the lookup table contains both the mean time-of-flight for each distance and time-of-arrival bin,\n", + "but also the variance inside each bin.\n", + "\n", + "In the regions where there is no time overlap,\n", + "the variance is small (the regions are close to a thin line).\n", + "However, in the central region where overlap occurs,\n", + "we are computing a mean between two regions which have similar 'brightness'.\n", + "\n", + "This leads to a large variance, and this is visible when plotting the variances on a 2D figure." ] }, { "cell_type": "code", "execution_count": null, - "id": "47", + "id": "42", "metadata": {}, "outputs": [], "source": [ - "# Chop the frames\n", - "chopped = frames.chop(choppers.values())\n", + "table = workflow.compute(unwrap.TimeOfFlightLookupTable)\n", + "table.plot() / sc.variances(table).plot(norm=\"log\")" + ] + }, + { + "cell_type": "markdown", + "id": "43", + "metadata": {}, + "source": [ + "The workflow has a parameter which is used to mask out regions where the variance is above a certain threshold.\n", "\n", - "# Propagate the neutrons to the detector\n", - "at_sample = chopped.propagate_to(Ltotal)\n", + "It is difficult to automatically detector this threshold,\n", + "as it can vary a lot depending on how much signal is received by the detectors,\n", + "and how far the detectors are from the source.\n", + "It is thus more robust to simply have a user tunable parameter on the workflow." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44", + "metadata": {}, + "outputs": [], + "source": [ + "workflow[unwrap.LookupTableVarianceThreshold] = 1.0e-3\n", "\n", - "# Visualize the results\n", - "cascade_fig, cascade_ax = at_sample.draw()\n", - "cascade_ax.set_xlim(45.0, 58.0)\n", - "cascade_ax.set_ylim(2.0, 3.0)" + "workflow.compute(unwrap.MaskedTimeOfFlightLookupTable).plot()" ] }, { "cell_type": "markdown", - "id": "48", + "id": "45", "metadata": {}, "source": [ - "To avoid the overlap in time, the region in the middle will be excluded,\n", - "discarding the neutrons from the time-of-flight calculation\n", + "We can now see that the central region is masked out.\n", + "\n", + "The neutrons in that region will be discarded in the time-of-flight calculation\n", "(in practice, they are given a NaN value as a time-of-flight).\n", "\n", "This is visible when comparing to the true neutron wavelengths,\n", @@ -749,14 +714,10 @@ { "cell_type": "code", "execution_count": null, - "id": "49", + "id": "46", "metadata": {}, "outputs": [], "source": [ - "workflow[unwrap.Choppers] = choppers\n", - "workflow[unwrap.Ltotal] = Ltotal\n", - "workflow[unwrap.RawData] = ess_beamline.get_monitor(\"detector\")[0]\n", - "\n", "# Compute time-of-flight\n", "tofs = workflow.compute(unwrap.TofData)\n", "# Compute wavelength\n", diff --git a/docs/user-guide/wfm/wfm-time-of-flight.ipynb b/docs/user-guide/wfm/wfm-time-of-flight.ipynb index c76fa16b3..6d55b8c21 100644 --- a/docs/user-guide/wfm/wfm-time-of-flight.ipynb +++ b/docs/user-guide/wfm/wfm-time-of-flight.ipynb @@ -26,8 +26,7 @@ "import scipp as sc\n", "import sciline as sl\n", "from scippneutron.chopper import DiskChopper\n", - "from scippneutron.tof import unwrap\n", - "from scippneutron.tof import chopper_cascade" + "from scippneutron.tof import unwrap" ] }, { @@ -200,31 +199,6 @@ "cell_type": "markdown", "id": "8", "metadata": {}, - "source": [ - "### Convert the choppers\n", - "\n", - "Lastly, we convert our disk choppers to a simpler chopper representation used by the `chopper_cascade` module. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9", - "metadata": {}, - "outputs": [], - "source": [ - "choppers = {\n", - " key: chopper_cascade.Chopper.from_disk_chopper(\n", - " chop, pulse_frequency=sc.scalar(14.0, unit=\"Hz\"), npulses=1\n", - " )\n", - " for key, chop in disk_choppers.items()\n", - "}" - ] - }, - { - "cell_type": "markdown", - "id": "10", - "metadata": {}, "source": [ "## Creating some neutron events\n", "\n", @@ -234,14 +208,14 @@ { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "9", "metadata": {}, "outputs": [], "source": [ "from scippneutron.tof.fakes import FakeBeamlineEss\n", "\n", "ess_beamline = FakeBeamlineEss(\n", - " choppers=choppers,\n", + " choppers=disk_choppers,\n", " monitors={\"detector\": Ltotal},\n", " run_length=sc.scalar(1 / 14, unit=\"s\") * 14,\n", " events_per_pulse=200_000,\n", @@ -250,7 +224,7 @@ }, { "cell_type": "markdown", - "id": "12", + "id": "10", "metadata": {}, "source": [ "The initial birth times and wavelengths of the generated neutrons can be visualized (for a single pulse):" @@ -259,7 +233,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "11", "metadata": {}, "outputs": [], "source": [ @@ -269,7 +243,7 @@ }, { "cell_type": "markdown", - "id": "14", + "id": "12", "metadata": {}, "source": [ "From this fake beamline, we extract the raw neutron signal at our detector:" @@ -278,7 +252,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -290,7 +264,7 @@ }, { "cell_type": "markdown", - "id": "16", + "id": "14", "metadata": {}, "source": [ "The total number of neutrons in our sample data that make it through the to detector is:" @@ -299,7 +273,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -308,165 +282,113 @@ }, { "cell_type": "markdown", - "id": "18", + "id": "16", "metadata": {}, "source": [ - "## Using the chopper cascade to chop the pulse\n", + "## Computing time-of-flight\n", "\n", - "The `chopper_cascade` module can now be used to chop a pulse of neutrons using the choppers created above." - ] - }, - { - "cell_type": "markdown", - "id": "19", - "metadata": {}, - "source": [ - "### Create a pulse of neutrons\n", + "The chopper information is next used to construct a lookup table that provides an estimate of the real time-of-flight as a function of neutron time-of-arrival.\n", "\n", - "We then create a (fake) pulse of neutrons, whose time and wavelength ranges cover our ESS pulse above:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "20", - "metadata": {}, - "outputs": [], - "source": [ - "time_min = sc.scalar(0.0, unit='ms')\n", - "time_max = sc.scalar(3.4, unit='ms')\n", - "wavs_min = sc.scalar(0.01, unit='angstrom')\n", - "wavs_max = sc.scalar(10.0, unit='angstrom')\n", - "\n", - "frames = chopper_cascade.FrameSequence.from_source_pulse(\n", - " time_min=time_min,\n", - " time_max=time_max,\n", - " wavelength_min=wavs_min,\n", - " wavelength_max=wavs_max,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "21", - "metadata": {}, - "source": [ - "### Propagate the neutrons through the choppers\n", + "We use the `tof` module to propagate a pulse of neutrons through the chopper system to the detectors,\n", + "and predict the most likely neutron wavelength for a given time-of-arrival and distance from source.\n", "\n", - "We are now able to propagate the pulse of neutrons through the chopper cascade,\n", - "chopping away the parts of the pulse that do not make it through.\n", + "From this,\n", + "we build a lookup table on which bilinear interpolation is used to compute a wavelength (and its corresponding time-of-flight)\n", + "for every neutron event.\n", "\n", - "For this, we need to decide how far we want to propagate the neutrons, by choosing a distance to our detector.\n", - "We set this to 32 meters here." + "### Setting up the workflow" ] }, { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "17", "metadata": {}, "outputs": [], "source": [ - "# Chop the frames\n", - "chopped = frames.chop(choppers.values())\n", + "workflow = sl.Pipeline(unwrap.providers(), params=unwrap.params())\n", "\n", - "# Propagate the neutrons to the detector\n", - "at_sample = chopped.propagate_to(Ltotal)\n", + "workflow[unwrap.Facility] = \"ess\"\n", + "workflow[unwrap.Choppers] = disk_choppers\n", + "workflow[unwrap.RawData] = raw_data\n", + "workflow[unwrap.Ltotal] = Ltotal\n", + "workflow[unwrap.NumberOfNeutrons] = 3_000_000\n", "\n", - "# Visualize the results\n", - "cascade_fig, cascade_ax = at_sample.draw()" - ] - }, - { - "cell_type": "markdown", - "id": "23", - "metadata": {}, - "source": [ - "We can now see that at the detector (pink color), we have 6 sub-pulses of neutrons,\n", - "where the longest wavelength of one frame is very close to the shortest wavelength of the next frame." + "workflow.visualize(unwrap.TofData)" ] }, { "cell_type": "markdown", - "id": "24", + "id": "18", "metadata": {}, "source": [ - "## Computing the time-of-flight coordinate using Sciline\n", + "### Inspecting the lookup table\n", "\n", - "### Setting up the workflow\n", + "The workflow first runs a `tof` simulation using the chopper parameters above,\n", + "and the result is stored in `SimulationResults` (see graph above).\n", "\n", - "We will now construct a workflow to compute the `time-of-flight` coordinate from the neutron events above,\n", - "taking into account the choppers in the beamline." + "From these simulated neutrons, we create figures displaying the neutron wavelengths and time-of-flight,\n", + "as a function of arrival time at the detector.\n", + "\n", + "This is the basis for creating our lookup table." ] }, { "cell_type": "code", "execution_count": null, - "id": "25", + "id": "19", "metadata": {}, "outputs": [], "source": [ - "workflow = sl.Pipeline(unwrap.providers(), params=unwrap.params())\n", - "\n", - "workflow[unwrap.PulsePeriod] = sc.reciprocal(ess_beamline.source.frequency)\n", - "workflow[unwrap.SourceTimeRange] = time_min, time_max\n", - "workflow[unwrap.SourceWavelengthRange] = wavs_min, wavs_max\n", - "workflow[unwrap.Choppers] = choppers\n", - "\n", - "workflow[unwrap.Ltotal] = Ltotal\n", - "workflow[unwrap.RawData] = raw_data\n", + "sim = workflow.compute(unwrap.SimulationResults)\n", + "# Compute time-of-arrival at the detector\n", + "tarrival = sim.time_of_arrival + ((Ltotal - sim.distance) / sim.speed).to(unit=\"us\")\n", + "# Compute time-of-flight at the detector\n", + "tflight = (Ltotal / sim.speed).to(unit=\"us\")\n", "\n", - "workflow.visualize(unwrap.TofData)" + "events = sc.DataArray(\n", + " data=sim.weight,\n", + " coords={\"wavelength\": sim.wavelength, \"toa\": tarrival, \"tof\": tflight},\n", + ")\n", + "fig1 = events.hist(wavelength=300, toa=300).plot(norm=\"log\")\n", + "fig2 = events.hist(tof=300, toa=300).plot(norm=\"log\")\n", + "fig1 + fig2" ] }, { "cell_type": "markdown", - "id": "26", + "id": "20", "metadata": {}, "source": [ - "### Checking the frame bounds\n", + "The lookup table is then obtained by computing the weighted mean of the time-of-flight inside each time-of-arrival bin.\n", "\n", - "We can check that the bounds for the frames the workflow computes agrees with the chopper-cascade diagram" + "This is illustrated by the orange line in the figure below:" ] }, { "cell_type": "code", "execution_count": null, - "id": "27", - "metadata": {}, - "outputs": [], - "source": [ - "bounds = workflow.compute(unwrap.FrameAtDetector).subbounds()\n", - "bounds" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "28", + "id": "21", "metadata": {}, "outputs": [], "source": [ - "for b in sc.collapse(bounds[\"time\"], keep=\"bound\").values():\n", - " cascade_ax.axvspan(\n", - " b[0].to(unit=\"ms\").value,\n", - " b[1].to(unit=\"ms\").value,\n", - " color=\"gray\",\n", - " alpha=0.3,\n", - " zorder=-5,\n", - " )\n", - "\n", - "cascade_fig" + "table = workflow.compute(unwrap.TimeOfFlightLookupTable)\n", + "\n", + "# Overlay mean on the figure above\n", + "table[\"distance\", 1].plot(ax=fig2.ax, color=\"C1\")\n", + "\n", + "# Zoom in\n", + "fig2.canvas.xrange = 40000, 50000\n", + "fig2.canvas.yrange = 35000, 50000\n", + "fig2" ] }, { "attachments": {}, "cell_type": "markdown", - "id": "29", + "id": "22", "metadata": {}, "source": [ - "There should be one vertical band matching the extent of each pink polygon, which there is.\n", - "\n", "### Computing a time-of-flight coordinate\n", "\n", "We will now use our workflow to obtain our event data with a time-of-flight coordinate:" @@ -475,7 +397,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -485,7 +407,7 @@ }, { "cell_type": "markdown", - "id": "31", + "id": "24", "metadata": {}, "source": [ "Histogramming the data for a plot should show a profile with 6 bumps that correspond to the frames:" @@ -494,7 +416,7 @@ { "cell_type": "code", "execution_count": null, - "id": "32", + "id": "25", "metadata": {}, "outputs": [], "source": [ @@ -504,7 +426,7 @@ }, { "cell_type": "markdown", - "id": "33", + "id": "26", "metadata": {}, "source": [ "### Converting to wavelength\n", @@ -515,7 +437,7 @@ { "cell_type": "code", "execution_count": null, - "id": "34", + "id": "27", "metadata": {}, "outputs": [], "source": [ @@ -534,7 +456,7 @@ }, { "cell_type": "markdown", - "id": "35", + "id": "28", "metadata": {}, "source": [ "### Comparing to the ground truth\n", @@ -546,7 +468,7 @@ { "cell_type": "code", "execution_count": null, - "id": "36", + "id": "29", "metadata": {}, "outputs": [], "source": [ diff --git a/pyproject.toml b/pyproject.toml index e3e9be32c..be8c10149 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,9 +35,11 @@ dependencies = [ "mpltoolbox", "numpy>=1.20", "plopp>=24.09.1", + "sciline", "scipp>=23.07.0", "scippnexus>=23.11.0", "scipy>=1.7.0", + "tof>=25.01.2", ] [project.optional-dependencies] @@ -51,7 +53,6 @@ test = [ "pytest", "pytest-xdist", "pythreejs", - "sciline", ] [project.urls] diff --git a/requirements/base.in b/requirements/base.in index 0828421a3..f62816c3e 100644 --- a/requirements/base.in +++ b/requirements/base.in @@ -6,6 +6,8 @@ h5py mpltoolbox numpy>=1.20 plopp>=24.09.1 +sciline scipp>=23.07.0 scippnexus>=23.11.0 scipy>=1.7.0 +tof>=25.01.2 diff --git a/requirements/base.txt b/requirements/base.txt index 5188caf85..24361d9cf 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,4 +1,4 @@ -# SHA1:8028b204890e7f7923037a4acbc74a34e7834749 +# SHA1:2e64d311bfcc14c32a4f22c48a21493d5dd212e7 # # This file is autogenerated by pip-compile-multi # To update, run: @@ -7,6 +7,8 @@ # contourpy==1.3.1 # via matplotlib +cyclebane==24.10.0 + # via sciline cycler==0.12.1 # via matplotlib fonttools==4.55.3 @@ -15,15 +17,19 @@ h5py==3.12.1 # via # -r base.in # scippnexus -kiwisolver==1.4.7 +importlib-resources==6.5.2 + # via tof +kiwisolver==1.4.8 # via matplotlib -matplotlib==3.9.3 +matplotlib==3.10.0 # via # mpltoolbox # plopp mpltoolbox==24.5.1 # via -r base.in -numpy==2.2.0 +networkx==3.4.2 + # via cyclebane +numpy==2.2.1 # via # -r base.in # contourpy @@ -34,25 +40,33 @@ numpy==2.2.0 # scipy packaging==24.2 # via matplotlib -pillow==11.0.0 +pillow==11.1.0 # via matplotlib plopp==24.10.0 - # via -r base.in -pyparsing==3.2.0 + # via + # -r base.in + # tof +pyparsing==3.2.1 # via matplotlib python-dateutil==2.9.0.post0 # via # matplotlib # scippnexus +sciline==24.10.0 + # via -r base.in scipp==24.11.2 # via # -r base.in # scippnexus + # tof scippnexus==24.11.1 # via -r base.in -scipy==1.14.1 +scipy==1.15.1 # via # -r base.in # scippnexus + # tof six==1.17.0 # via python-dateutil +tof==25.1.2 + # via -r base.in diff --git a/requirements/basetest.in b/requirements/basetest.in index 554828ae1..66346df9c 100644 --- a/requirements/basetest.in +++ b/requirements/basetest.in @@ -15,4 +15,3 @@ psutil pytest pytest-xdist pythreejs -sciline diff --git a/requirements/basetest.txt b/requirements/basetest.txt index 9c55c1418..907fb5d9e 100644 --- a/requirements/basetest.txt +++ b/requirements/basetest.txt @@ -1,4 +1,4 @@ -# SHA1:3605566d5f9fcb54bf956185636936ff6ede5b36 +# SHA1:8bd1f136c0aede63207a0ce2d0da4202e31b3494 # # This file is autogenerated by pip-compile-multi # To update, run: @@ -9,25 +9,23 @@ appdirs==1.4.4 # via pace-neutrons asttokens==3.0.0 # via stack-data -attrs==24.2.0 +attrs==24.3.0 # via hypothesis brille==0.8.0 # via pace-neutrons -certifi==2024.8.30 +certifi==2024.12.14 # via requests -charset-normalizer==3.4.0 +charset-normalizer==3.4.1 # via requests comm==0.2.2 # via ipywidgets contourpy==1.3.1 # via matplotlib -cyclebane==24.10.0 - # via sciline cycler==0.12.1 # via matplotlib decorator==5.1.1 # via ipython -euphonic[phonopy-reader]==1.4.0 +euphonic[phonopy-reader]==1.4.0.post1 # via pace-neutrons exceptiongroup==1.2.2 # via @@ -46,7 +44,7 @@ fonttools==4.55.3 # via matplotlib h5py==3.12.1 # via euphonic -hypothesis==6.122.3 +hypothesis==6.124.0 # via -r basetest.in idna==3.10 # via requests @@ -54,15 +52,13 @@ iniconfig==2.0.0 # via pytest ipydatawidgets==4.3.5 # via pythreejs -ipympl==0.9.4 +ipympl==0.9.6 # via -r basetest.in -ipython==8.30.0 +ipython==8.31.0 # via # ipympl # ipywidgets # pace-neutrons -ipython-genutils==0.2.0 - # via ipympl ipywidgets==8.1.5 # via # ipydatawidgets @@ -72,19 +68,17 @@ jedi==0.19.2 # via ipython jupyterlab-widgets==3.0.13 # via ipywidgets -kiwisolver==1.4.7 +kiwisolver==1.4.8 # via matplotlib libpymcr==0.1.8 # via pace-neutrons -matplotlib==3.9.3 +matplotlib==3.10.0 # via # ipympl # pace-neutrons matplotlib-inline==0.1.7 # via ipython -networkx==3.4.2 - # via cyclebane -numpy==2.2.0 +numpy==2.2.1 # via # brille # contourpy @@ -111,7 +105,7 @@ parso==0.8.4 # via jedi pexpect==4.9.0 # via ipython -pillow==11.0.0 +pillow==11.1.0 # via # ipympl # matplotlib @@ -127,7 +121,7 @@ pooch==1.8.2 # via -r basetest.in prompt-toolkit==3.0.48 # via ipython -psutil==6.1.0 +psutil==6.1.1 # via # -r basetest.in # pace-neutrons @@ -135,9 +129,9 @@ ptyprocess==0.7.0 # via pexpect pure-eval==0.2.3 # via stack-data -pygments==2.18.0 +pygments==2.19.1 # via ipython -pyparsing==3.2.0 +pyparsing==3.2.1 # via matplotlib pytest==8.3.4 # via @@ -155,9 +149,7 @@ requests==2.32.3 # via # pace-neutrons # pooch -sciline==24.10.0 - # via -r basetest.in -scipy==1.14.1 +scipy==1.15.1 # via euphonic seekpath==2.1.0 # via euphonic @@ -196,7 +188,7 @@ typing-extensions==4.12.2 # flexparser # ipython # pint -urllib3==2.2.3 +urllib3==2.3.0 # via requests wcwidth==0.2.13 # via prompt-toolkit diff --git a/requirements/ci.txt b/requirements/ci.txt index 045acba62..5ed400510 100644 --- a/requirements/ci.txt +++ b/requirements/ci.txt @@ -7,11 +7,11 @@ # cachetools==5.5.0 # via tox -certifi==2024.8.30 +certifi==2024.12.14 # via requests chardet==5.2.0 # via tox -charset-normalizer==3.4.0 +charset-normalizer==3.4.1 # via requests colorama==0.4.6 # via tox @@ -21,9 +21,9 @@ filelock==3.16.1 # via # tox # virtualenv -gitdb==4.0.11 +gitdb==4.0.12 # via gitpython -gitpython==3.1.43 +gitpython==3.1.44 # via -r ci.in idna==3.10 # via requests @@ -42,7 +42,7 @@ pyproject-api==1.8.0 # via tox requests==2.32.3 # via -r ci.in -smmap==5.0.1 +smmap==5.0.2 # via gitdb tomli==2.2.1 # via @@ -52,7 +52,7 @@ tox==4.23.2 # via -r ci.in typing-extensions==4.12.2 # via tox -urllib3==2.2.3 +urllib3==2.3.0 # via requests -virtualenv==20.28.0 +virtualenv==20.29.0 # via tox diff --git a/requirements/dev.txt b/requirements/dev.txt index faa105fa6..9efc4d9fa 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -14,7 +14,7 @@ -r wheels.txt annotated-types==0.7.0 # via pydantic -anyio==4.7.0 +anyio==4.8.0 # via # httpx # jupyter-server @@ -28,7 +28,7 @@ async-lru==2.0.4 # via jupyterlab cffi==1.17.1 # via argon2-cffi-bindings -click==8.1.7 +click==8.1.8 # via # pip-compile-multi # pip-tools @@ -59,11 +59,11 @@ jsonschema[format-nongpl]==4.23.0 # jupyter-events # jupyterlab-server # nbformat -jupyter-events==0.10.0 +jupyter-events==0.11.0 # via jupyter-server jupyter-lsp==2.2.5 # via jupyterlab -jupyter-server==2.14.2 +jupyter-server==2.15.0 # via # jupyter-lsp # jupyterlab @@ -71,7 +71,7 @@ jupyter-server==2.14.2 # notebook-shim jupyter-server-terminals==0.5.3 # via jupyter-server -jupyterlab==4.3.3 +jupyterlab==4.3.4 # via -r dev.in jupyterlab-server==2.27.3 # via jupyterlab @@ -91,13 +91,13 @@ prometheus-client==0.21.1 # via jupyter-server pycparser==2.22 # via cffi -pydantic==2.10.3 +pydantic==2.10.5 # via copier -pydantic-core==2.27.1 +pydantic-core==2.27.2 # via pydantic -python-json-logger==2.0.7 +python-json-logger==3.2.1 # via jupyter-events -questionary==1.10.0 +questionary==2.1.0 # via copier rfc3339-validator==0.1.4 # via diff --git a/requirements/docs.txt b/requirements/docs.txt index 0ad13bfac..80c534374 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -12,7 +12,7 @@ alabaster==1.0.0 # via sphinx asttokens==3.0.0 # via stack-data -attrs==24.2.0 +attrs==24.3.0 # via # jsonschema # referencing @@ -24,17 +24,17 @@ beautifulsoup4==4.12.3 # via # nbconvert # pydata-sphinx-theme -bleach==6.2.0 +bleach[css]==6.2.0 # via nbconvert -certifi==2024.8.30 +certifi==2024.12.14 # via requests -charset-normalizer==3.4.0 +charset-normalizer==3.4.1 # via requests comm==0.2.2 # via # ipykernel # ipywidgets -debugpy==1.8.10 +debugpy==1.8.11 # via ipykernel decorator==5.1.1 # via ipython @@ -58,25 +58,21 @@ idna==3.10 # via requests imagesize==1.4.1 # via sphinx -importlib-resources==6.4.5 - # via tof ipykernel==6.29.5 # via -r docs.in -ipympl==0.9.4 +ipympl==0.9.6 # via -r docs.in -ipython==8.30.0 +ipython==8.31.0 # via # -r docs.in # ipykernel # ipympl # ipywidgets -ipython-genutils==0.2.0 - # via ipympl ipywidgets==8.1.5 # via ipympl jedi==0.19.2 # via ipython -jinja2==3.1.4 +jinja2==3.1.5 # via # myst-parser # nbconvert @@ -119,20 +115,20 @@ mdit-py-plugins==0.4.2 # via myst-parser mdurl==0.1.2 # via markdown-it-py -mistune==3.0.2 +mistune==3.1.0 # via nbconvert myst-parser==4.0.0 # via -r docs.in -nbclient==0.10.1 +nbclient==0.10.2 # via nbconvert -nbconvert==7.16.4 +nbconvert==7.16.5 # via nbsphinx nbformat==5.10.4 # via # nbclient # nbconvert # nbsphinx -nbsphinx==0.9.5 +nbsphinx==0.9.6 # via -r docs.in nest-asyncio==1.6.0 # via ipykernel @@ -150,7 +146,7 @@ pooch==1.8.2 # via -r docs.in prompt-toolkit==3.0.48 # via ipython -psutil==6.1.0 +psutil==6.1.1 # via ipykernel ptyprocess==0.7.0 # via pexpect @@ -162,9 +158,9 @@ pybtex==0.24.0 # sphinxcontrib-bibtex pybtex-docutils==1.0.3 # via sphinxcontrib-bibtex -pydata-sphinx-theme==0.16.0 +pydata-sphinx-theme==0.16.1 # via -r docs.in -pygments==2.18.0 +pygments==2.19.1 # via # accessible-pygments # ipython @@ -205,7 +201,7 @@ sphinx==8.1.3 # sphinx-copybutton # sphinx-design # sphinxcontrib-bibtex -sphinx-autodoc-typehints==2.5.0 +sphinx-autodoc-typehints==3.0.0 # via -r docs.in sphinx-copybutton==0.5.2 # via -r docs.in @@ -228,9 +224,7 @@ sphinxcontrib-serializinghtml==2.0.0 stack-data==0.6.3 # via ipython tinycss2==1.4.0 - # via nbconvert -tof==24.12.0 - # via -r docs.in + # via bleach tomli==2.2.1 # via sphinx tornado==6.4.2 @@ -254,8 +248,9 @@ traitlets==5.14.3 typing-extensions==4.12.2 # via # ipython + # mistune # pydata-sphinx-theme -urllib3==2.2.3 +urllib3==2.3.0 # via requests wcwidth==0.2.13 # via prompt-toolkit diff --git a/requirements/mypy.txt b/requirements/mypy.txt index 0b6fa4cce..d7a49e8a6 100644 --- a/requirements/mypy.txt +++ b/requirements/mypy.txt @@ -6,7 +6,7 @@ # pip-compile-multi # -r test.txt -mypy==1.13.0 +mypy==1.14.1 # via -r mypy.in mypy-extensions==1.0.0 # via mypy diff --git a/requirements/nightly.txt b/requirements/nightly.txt index c5505b2d6..62f71ac3e 100644 --- a/requirements/nightly.txt +++ b/requirements/nightly.txt @@ -9,13 +9,13 @@ appdirs==1.4.4 # via pace-neutrons asttokens==3.0.0 # via stack-data -attrs==24.2.0 +attrs==24.3.0 # via hypothesis brille==0.8.0 # via pace-neutrons -certifi==2024.8.30 +certifi==2024.12.14 # via requests -charset-normalizer==3.4.0 +charset-normalizer==3.4.1 # via requests comm==0.2.2 # via ipywidgets @@ -27,7 +27,7 @@ cycler==0.12.1 # via matplotlib decorator==5.1.1 # via ipython -euphonic[phonopy-reader]==1.4.0 +euphonic[phonopy-reader]==1.4.0.post1 # via pace-neutrons exceptiongroup==1.2.2 # via @@ -49,25 +49,23 @@ h5py==3.12.1 # -r nightly.in # euphonic # scippnexus -hypothesis==6.122.3 +hypothesis==6.124.0 # via -r nightly.in idna==3.10 # via requests -importlib-resources==6.4.5 +importlib-resources==6.5.2 # via tof iniconfig==2.0.0 # via pytest ipydatawidgets==4.3.5 # via pythreejs -ipympl==0.9.4 +ipympl==0.9.6 # via -r nightly.in -ipython==8.30.0 +ipython==8.31.0 # via # ipympl # ipywidgets # pace-neutrons -ipython-genutils==0.2.0 - # via ipympl ipywidgets==8.1.5 # via # ipydatawidgets @@ -77,11 +75,11 @@ jedi==0.19.2 # via ipython jupyterlab-widgets==3.0.13 # via ipywidgets -kiwisolver==1.4.7 +kiwisolver==1.4.8 # via matplotlib libpymcr==0.1.8 # via pace-neutrons -matplotlib==3.9.3 +matplotlib==3.10.0 # via # ipympl # mpltoolbox @@ -93,7 +91,7 @@ mpltoolbox @ git+https://github.com/scipp/mpltoolbox@main # via -r nightly.in networkx==3.4.2 # via cyclebane -numpy==2.2.0 +numpy==2.2.1 # via # -r nightly.in # brille @@ -123,7 +121,7 @@ parso==0.8.4 # via jedi pexpect==4.9.0 # via ipython -pillow==11.0.0 +pillow==11.1.0 # via # ipympl # matplotlib @@ -143,7 +141,7 @@ pooch==1.8.2 # via -r nightly.in prompt-toolkit==3.0.48 # via ipython -psutil==6.1.0 +psutil==6.1.1 # via # -r nightly.in # pace-neutrons @@ -151,9 +149,9 @@ ptyprocess==0.7.0 # via pexpect pure-eval==0.2.3 # via stack-data -pygments==2.18.0 +pygments==2.19.1 # via ipython -pyparsing==3.2.0 +pyparsing==3.2.1 # via matplotlib pytest==8.3.4 # via @@ -182,7 +180,7 @@ scipp @ https://github.com/scipp/scipp/releases/download/nightly/scipp-nightly-c # tof scippnexus @ git+https://github.com/scipp/scippnexus@main # via -r nightly.in -scipy==1.14.1 +scipy==1.15.1 # via # -r nightly.in # euphonic @@ -227,7 +225,7 @@ typing-extensions==4.12.2 # flexparser # ipython # pint -urllib3==2.2.3 +urllib3==2.3.0 # via requests wcwidth==0.2.13 # via prompt-toolkit diff --git a/requirements/static.txt b/requirements/static.txt index 1d6ae539b..3819f7726 100644 --- a/requirements/static.txt +++ b/requirements/static.txt @@ -11,7 +11,7 @@ distlib==0.3.9 # via virtualenv filelock==3.16.1 # via virtualenv -identify==2.6.3 +identify==2.6.5 # via pre-commit nodeenv==1.9.1 # via pre-commit @@ -21,5 +21,5 @@ pre-commit==4.0.1 # via -r static.in pyyaml==6.0.2 # via pre-commit -virtualenv==20.28.0 +virtualenv==20.29.0 # via pre-commit diff --git a/requirements/test.txt b/requirements/test.txt index 1a9a00ed7..cbb7a53b1 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -7,7 +7,3 @@ # -r base.txt -r basetest.txt -importlib-resources==6.4.5 - # via tof -tof==24.12.0 - # via -r test.in diff --git a/src/scippneutron/tof/__init__.py b/src/scippneutron/tof/__init__.py index 63c99168c..c841a6d43 100644 --- a/src/scippneutron/tof/__init__.py +++ b/src/scippneutron/tof/__init__.py @@ -10,8 +10,5 @@ from . import chopper_cascade, unwrap from .diagram import TimeDistanceDiagram -__all__ = [ - 'chopper_cascade', - 'unwrap', - 'TimeDistanceDiagram', -] + +__all__ = ['chopper_cascade', 'unwrap', 'TimeDistanceDiagram'] diff --git a/src/scippneutron/tof/chopper_cascade.py b/src/scippneutron/tof/chopper_cascade.py index 268d328dc..6e2cdbb4d 100644 --- a/src/scippneutron/tof/chopper_cascade.py +++ b/src/scippneutron/tof/chopper_cascade.py @@ -66,9 +66,21 @@ def __init__(self, time: sc.Variable, wavelength: sc.Variable): def __eq__(self, other: object) -> bool: if not isinstance(other, Subframe): return NotImplemented - return sc.identical(self.time, other.time) and sc.identical( - self.wavelength, other.wavelength + # Using sc.identical can lead to flaky behavior (different frames might have + # been obtained via a different operation order), so we use allclose instead. + same_time = sc.allclose( + self.time, + other.time, + rtol=sc.scalar(1.0e-12), + atol=sc.scalar(1.0e-16, unit=self.time.unit), ) + same_wavelength = sc.allclose( + self.wavelength, + other.wavelength, + rtol=sc.scalar(1.0e-12), + atol=sc.scalar(1.0e-16, unit=self.wavelength.unit), + ) + return same_time and same_wavelength def is_regular(self) -> bool: """ diff --git a/src/scippneutron/tof/fakes.py b/src/scippneutron/tof/fakes.py index c44f76644..81e92cf24 100644 --- a/src/scippneutron/tof/fakes.py +++ b/src/scippneutron/tof/fakes.py @@ -19,6 +19,7 @@ import scipp as sc from numpy import random +from ..chopper import DiskChopper from . import chopper_cascade @@ -233,7 +234,7 @@ def _fake_monitor( class FakeBeamlineEss: def __init__( self, - choppers: dict[str, chopper_cascade.Chopper], + choppers: dict[str, DiskChopper], monitors: dict[str, sc.Variable], run_length: sc.Variable, events_per_pulse: int = 200000, @@ -257,36 +258,20 @@ def __init__( self.source = source(pulses=self.npulses) # Convert the choppers to tof.Chopper - def _open_close_angles(chopper, frequency): - angular_speed = sc.constants.pi * (2.0 * sc.units.rad) * frequency - return ( - chopper.time_open * angular_speed, - chopper.time_close * angular_speed, - ) - - self.choppers = [] - for name, ch in choppers.items(): - frequency = self.frequency - open_angles, close_angles = _open_close_angles(ch, frequency) - # If the difference between open and close angles is larger than 2pi, - # the boundaries have crossed, which means that the chopper is rotating - # at a lower frequency. - two_pi = np.pi * 2 - if any(abs(np.diff(open_angles.values) > two_pi)) or any( - abs(np.diff(close_angles.values) > two_pi) - ): - frequency = 0.5 * frequency - open_angles, close_angles = _open_close_angles(ch, frequency) - self.choppers.append( - tof_pkg.Chopper( - frequency=frequency, - open=open_angles, - close=close_angles, - phase=sc.scalar(0.0, unit='rad'), - distance=ch.distance, - name=name, - ) + self.choppers = [ + tof_pkg.Chopper( + frequency=abs(ch.frequency), + direction=tof_pkg.AntiClockwise + if (ch.frequency.value > 0.0) + else tof_pkg.Clockwise, + open=ch.slit_begin, + close=ch.slit_end, + phase=abs(ch.phase), + distance=ch.axle_position.fields.z, + name=name, ) + for name, ch in choppers.items() + ] # Add detectors self.monitors = [ @@ -507,3 +492,140 @@ def get_monitor(self, name: str) -> sc.DataGroup: wavelength_max=ess_wavelength_max, ) psc_frames = psc_frames.chop(psc_choppers.values()) + + +wfm1_disk_chopper = DiskChopper( + frequency=sc.scalar(-70.0, unit="Hz"), + beam_position=sc.scalar(0.0, unit="deg"), + phase=sc.scalar(-47.10, unit="deg"), + axle_position=sc.vector(value=[0, 0, 6.6], unit="m"), + slit_begin=sc.array( + dims=["cutout"], + values=np.array([83.71, 140.49, 193.26, 242.32, 287.91, 330.3]) + 15.0, + unit="deg", + ), + slit_end=sc.array( + dims=["cutout"], + values=np.array([94.7, 155.79, 212.56, 265.33, 314.37, 360.0]) + 15.0, + unit="deg", + ), + slit_height=sc.scalar(10.0, unit="cm"), + radius=sc.scalar(30.0, unit="cm"), +) + +wfm2_disk_chopper = DiskChopper( + frequency=sc.scalar(-70.0, unit="Hz"), + beam_position=sc.scalar(0.0, unit="deg"), + phase=sc.scalar(-76.76, unit="deg"), + axle_position=sc.vector(value=[0, 0, 7.1], unit="m"), + slit_begin=sc.array( + dims=["cutout"], + values=np.array([65.04, 126.1, 182.88, 235.67, 284.73, 330.32]) + 15.0, + unit="deg", + ), + slit_end=sc.array( + dims=["cutout"], + values=np.array([76.03, 141.4, 202.18, 254.97, 307.74, 360.0]) + 15.0, + unit="deg", + ), + slit_height=sc.scalar(10.0, unit="cm"), + radius=sc.scalar(30.0, unit="cm"), +) + +foc1_disk_chopper = DiskChopper( + frequency=sc.scalar(-56.0, unit="Hz"), + beam_position=sc.scalar(0.0, unit="deg"), + phase=sc.scalar(-62.40, unit="deg"), + axle_position=sc.vector(value=[0, 0, 8.8], unit="m"), + slit_begin=sc.array( + dims=["cutout"], + values=np.array([74.6, 139.6, 194.3, 245.3, 294.8, 347.2]), + unit="deg", + ), + slit_end=sc.array( + dims=["cutout"], + values=np.array([95.2, 162.8, 216.1, 263.1, 310.5, 371.6]), + unit="deg", + ), + slit_height=sc.scalar(10.0, unit="cm"), + radius=sc.scalar(30.0, unit="cm"), +) + +foc2_disk_chopper = DiskChopper( + frequency=sc.scalar(-28.0, unit="Hz"), + beam_position=sc.scalar(0.0, unit="deg"), + phase=sc.scalar(-12.27, unit="deg"), + axle_position=sc.vector(value=[0, 0, 15.9], unit="m"), + slit_begin=sc.array( + dims=["cutout"], + values=np.array([98.0, 154.0, 206.8, 255.0, 299.0, 344.65]), + unit="deg", + ), + slit_end=sc.array( + dims=["cutout"], + values=np.array([134.6, 190.06, 237.01, 280.88, 323.56, 373.76]), + unit="deg", + ), + slit_height=sc.scalar(10.0, unit="cm"), + radius=sc.scalar(30.0, unit="cm"), +) + +pol_disk_chopper = DiskChopper( + frequency=sc.scalar(-14.0, unit="Hz"), + beam_position=sc.scalar(0.0, unit="deg"), + phase=sc.scalar(0.0, unit="deg"), + axle_position=sc.vector(value=[0, 0, 17.0], unit="m"), + slit_begin=sc.array( + dims=["cutout"], + values=np.array([40.0]), + unit="deg", + ), + slit_end=sc.array( + dims=["cutout"], + values=np.array([240.0]), + unit="deg", + ), + slit_height=sc.scalar(10.0, unit="cm"), + radius=sc.scalar(30.0, unit="cm"), +) + +pulse_skipping = DiskChopper( + frequency=sc.scalar(-7.0, unit="Hz"), + beam_position=sc.scalar(0.0, unit="deg"), + phase=sc.scalar(0.0, unit="deg"), + axle_position=sc.vector(value=[0, 0, 30.0], unit="m"), + slit_begin=sc.array( + dims=["cutout"], + values=np.array([40.0]), + unit="deg", + ), + slit_end=sc.array( + dims=["cutout"], + values=np.array([140.0]), + unit="deg", + ), + slit_height=sc.scalar(10.0, unit="cm"), + radius=sc.scalar(30.0, unit="cm"), +) + +wfm_disk_choppers = { + "wfm1": wfm1_disk_chopper, + "wfm2": wfm2_disk_chopper, + "foc1": foc1_disk_chopper, + "foc2": foc2_disk_chopper, + "pol": pol_disk_chopper, +} + +psc_disk_choppers = { + name: DiskChopper( + frequency=ch.frequency, + beam_position=ch.beam_position, + phase=ch.phase, + axle_position=ch.axle_position, + slit_begin=ch.slit_begin[0:1], + slit_end=ch.slit_end[0:1], + slit_height=ch.slit_height[0:1], + radius=ch.radius, + ) + for name, ch in wfm_disk_choppers.items() +} diff --git a/src/scippneutron/tof/tof_simulation.py b/src/scippneutron/tof/tof_simulation.py new file mode 100644 index 000000000..2a072a742 --- /dev/null +++ b/src/scippneutron/tof/tof_simulation.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +import scipp as sc + +from .unwrap import ( + Choppers, + Facility, + NumberOfNeutrons, + SimulationResults, + SimulationSeed, +) + + +def run_tof_simulation( + facility: Facility, + choppers: Choppers, + seed: SimulationSeed, + number_of_neutrons: NumberOfNeutrons, +) -> SimulationResults: + import tof as tof_pkg + + tof_choppers = [ + tof_pkg.Chopper( + frequency=abs(ch.frequency), + direction=tof_pkg.AntiClockwise + if (ch.frequency.value > 0.0) + else tof_pkg.Clockwise, + open=ch.slit_begin, + close=ch.slit_end, + phase=abs(ch.phase), + distance=ch.axle_position.fields.z, + name=name, + ) + for name, ch in choppers.items() + ] + source = tof_pkg.Source(facility=facility, neutrons=number_of_neutrons, seed=seed) + if not tof_choppers: + events = source.data.squeeze() + return SimulationResults( + time_of_arrival=events.coords['time'], + speed=events.coords['speed'], + wavelength=events.coords['wavelength'], + weight=events.data, + distance=0.0 * sc.units.m, + ) + model = tof_pkg.Model(source=source, choppers=tof_choppers) + results = model.run() + # Find name of the furthest chopper in tof_choppers + furthest_chopper = max(tof_choppers, key=lambda c: c.distance) + events = results[furthest_chopper.name].data.squeeze() + events = events[ + ~(events.masks['blocked_by_others'] | events.masks['blocked_by_me']) + ] + return SimulationResults( + time_of_arrival=events.coords['toa'], + speed=events.coords['speed'], + wavelength=events.coords['wavelength'], + weight=events.data, + distance=furthest_chopper.distance, + ) diff --git a/src/scippneutron/tof/unwrap.py b/src/scippneutron/tof/unwrap.py index 30fb3eaca..ea7223b74 100644 --- a/src/scippneutron/tof/unwrap.py +++ b/src/scippneutron/tof/unwrap.py @@ -1,47 +1,109 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2024 Scipp contributors (https://github.com/scipp) # @author Simon Heybrock +# @author Neil Vaytet """ -This module provides functionality for unwrapping raw frames of neutron time-of-flight -data. - -The module handles standard unwrapping, unwrapping in pulse-skipping mode, and -unwrapping for WFM instruments, as well as combinations of the latter two. The -functions defined here are meant to be used as providers for a Sciline pipeline. See -https://scipp.github.io/sciline/ on how to use Sciline. +Time-of-flight workflow for unwrapping the time of arrival of the neutron at the +detector. +This workflow is used to convert raw detector data with event_time_zero and +event_time_offset coordinates to data with a time-of-flight coordinate. """ +from __future__ import annotations + from collections.abc import Callable, Mapping from dataclasses import dataclass -from typing import NewType +from functools import reduce +from typing import Any, NewType import numpy as np import scipp as sc -from scipp.core.bins import Lookup +from scipp._scipp.core import _bins_no_validate from .._utils import elem_unit -from . import chopper_cascade +from ..chopper import DiskChopper from .to_events import to_events -Choppers = NewType('Choppers', Mapping[str, chopper_cascade.Chopper]) +Facility = NewType('Facility', str) +""" +Facility where the experiment is performed. +""" + +Choppers = NewType('Choppers', Mapping[str, DiskChopper]) """ Choppers used to define the frame parameters. """ -ChopperCascadeFrames = NewType( - 'ChopperCascadeFrames', list[chopper_cascade.FrameSequence] -) +Ltotal = NewType('Ltotal', sc.Variable) """ -Frames of the chopper cascade. +Total length of the flight path from the source to the detector. """ -FrameAtDetector = NewType('FrameAtDetector', chopper_cascade.Frame) +SimulationSeed = NewType('SimulationSeed', int) +""" +Seed for the random number generator used in the simulation. """ -Result of passing the source pulse through the chopper cascade to the detector. -The detector may be a monitor or a detector after scattering off the sample. The frame -bounds are then computed from this. + +NumberOfNeutrons = NewType('NumberOfNeutrons', int) """ +Number of neutrons to use in the simulation. +""" + + +@dataclass +class SimulationResults: + """ + Results of a time-of-flight simulation used to create a lookup table. + """ + + time_of_arrival: sc.Variable + speed: sc.Variable + wavelength: sc.Variable + weight: sc.Variable + distance: sc.Variable + + +@dataclass +class FastestNeutron: + """ + Properties of the fastest neutron in the simulation results. + """ + + time_of_arrival: sc.Variable + speed: sc.Variable + distance: sc.Variable + + +LtotalRange = NewType('LtotalRange', tuple[sc.Variable, sc.Variable]) +""" +Range (min, max) of the total length of the flight path from the source to the detector. +""" + + +DistanceResolution = NewType('DistanceResolution', sc.Variable) +""" +Resolution of the distance axis in the lookup table. +""" + +TimeOfArrivalResolution = NewType('TimeOfArrivalResolution', int | sc.Variable) +""" +Resolution of the time of arrival axis in the lookup table. +Can be an integer (number of bins) or a sc.Variable (bin width). +""" + +TimeOfFlightLookupTable = NewType('TimeOfFlightLookupTable', sc.DataArray) +""" +Lookup table giving time-of-flight as a function of distance and time of arrival. +""" + +MaskedTimeOfFlightLookupTable = NewType('MaskedTimeOfFlightLookupTable', sc.DataArray) +""" +Lookup table giving time-of-flight as a function of distance and time of arrival, with +regions of large uncertainty masked out. +""" + +LookupTableVarianceThreshold = NewType('LookupTableVarianceThreshold', float) FramePeriod = NewType('FramePeriod', sc.Variable) """ @@ -53,9 +115,9 @@ Time of arrival of the neutron at the detector, unwrapped at the pulse period. """ -FrameAtDetectorStartTime = NewType('FrameAtDetectorStartTime', sc.Variable) +PivotTimeAtDetector = NewType('PivotTimeAtDetector', sc.Variable) """ -Time of the start of the frame at the detector. +Pivot time at the detector, i.e., the time of the start of the frame at the detector. """ UnwrappedTimeOfArrivalMinusStartTime = NewType( @@ -74,28 +136,8 @@ modulo the frame period. """ +FrameFoldedTimeOfArrival = NewType('FrameFoldedTimeOfArrival', sc.Variable) -@dataclass -class TimeOfArrivalToTimeOfFlight: - """ """ - - slope: Lookup - intercept: Lookup - - -TofCoord = NewType('TofCoord', sc.Variable) -""" -Tof coordinate computed by the workflow. -""" - -Ltotal = NewType('Ltotal', sc.Variable) -""" -Total distance between the source and the detector(s). - -This is used to propagate the frame to the detector position. This will then yield -detector-dependent frame bounds. This is typically the sum of L1 and L2, except for -monitors. -""" PulsePeriod = NewType('PulsePeriod', sc.Variable) """ @@ -117,18 +159,6 @@ class TimeOfArrivalToTimeOfFlight: Raw detector data loaded from a NeXus file, e.g., NXdetector containing NXevent_data. """ -SourceTimeRange = NewType('SourceTimeRange', tuple[sc.Variable, sc.Variable]) -""" -Time range of the source pulse, used for computing frame bounds. -""" - -SourceWavelengthRange = NewType( - 'SourceWavelengthRange', tuple[sc.Variable, sc.Variable] -) -""" -Wavelength range of the source pulse, used for computing frame bounds. -""" - TofData = NewType('TofData', sc.DataArray) """ Detector data with time-of-flight coordinate. @@ -140,6 +170,20 @@ class TimeOfArrivalToTimeOfFlight: """ +def pulse_period_from_source(facility: Facility) -> PulsePeriod: + """ + Return the period of the source pulses, i.e., time between consecutive pulse starts. + + Parameters + ---------- + facility: + Facility where the experiment is performed (used to determine the source pulse + parameters). + """ + facilities = {"ess": sc.scalar(14.0, unit='Hz')} + return PulsePeriod(1.0 / facilities[facility]) + + def frame_period(pulse_period: PulsePeriod, pulse_stride: PulseStride) -> FramePeriod: """ Return the period of a frame, which is defined by the pulse period times the pulse @@ -156,104 +200,157 @@ def frame_period(pulse_period: PulsePeriod, pulse_stride: PulseStride) -> FrameP return FramePeriod(pulse_period * pulse_stride) -def chopper_cascade_frames( - source_wavelength_range: SourceWavelengthRange, - source_time_range: SourceTimeRange, - choppers: Choppers, - pulse_stride: PulseStride, - pulse_period: PulsePeriod, -) -> ChopperCascadeFrames: +def extract_ltotal(da: RawData) -> Ltotal: """ - Return the frames of the chopper cascade. - This is the result of propagating the source pulse through the chopper cascade. + Extract the total length of the flight path from the source to the detector from the + detector data. - In the case of pulse-skipping, the frames are computed for each pulse in the stride, - to make sure that we include cases where e.g. the first pulse in the stride is - skipped, but the second is not. + Parameters + ---------- + da: + Raw detector data loaded from a NeXus file, e.g., NXdetector containing + NXevent_data. + """ + return Ltotal(da.coords["Ltotal"]) + + +def compute_tof_lookup_table( + simulation: SimulationResults, + ltotal_range: LtotalRange, + distance_resolution: DistanceResolution, + toa_resolution: TimeOfArrivalResolution, +) -> TimeOfFlightLookupTable: + distance_unit = 'm' + res = distance_resolution.to(unit=distance_unit) + simulation_distance = simulation.distance.to(unit=distance_unit) + + # We need to bin the data below, to compute the weighted mean of the wavelength. + # This results in data with bin edges. + # However, the 2d interpolator expects bin centers. + # We want to give the 2d interpolator a table that covers the requested range, + # hence we need to extend the range by half a resolution in each direction. + min_dist, max_dist = [ + x.to(unit=distance_unit) - simulation_distance for x in ltotal_range + ] + min_dist, max_dist = min_dist - 0.5 * res, max_dist + 0.5 * res + + dist_edges = sc.array( + dims=['distance'], + values=np.arange( + min_dist.value, np.nextafter(max_dist.value, np.inf), res.value + ), + unit=distance_unit, + ) + distances = sc.midpoints(dist_edges) + + time_unit = simulation.time_of_arrival.unit + toas = simulation.time_of_arrival + (distances / simulation.speed).to( + unit=time_unit, copy=False + ) + + data = sc.DataArray( + data=sc.broadcast(simulation.weight, sizes=toas.sizes).flatten(to='event'), + coords={ + 'toa': toas.flatten(to='event'), + 'wavelength': sc.broadcast(simulation.wavelength, sizes=toas.sizes).flatten( + to='event' + ), + 'distance': sc.broadcast(distances, sizes=toas.sizes).flatten(to='event'), + }, + ) + + binned = data.bin(distance=dist_edges, toa=toa_resolution) + # Weighted mean of wavelength inside each bin + wavelength = ( + binned.bins.data * binned.bins.coords['wavelength'] + ).bins.sum() / binned.bins.sum() + # Compute the variance of the wavelength to track regions with large uncertainty + variance = ( + binned.bins.data * (binned.bins.coords['wavelength'] - wavelength) ** 2 + ).bins.sum() / binned.bins.sum() + + # Need to add the simulation distance to the distance coordinate + wavelength.coords['distance'] = wavelength.coords['distance'] + simulation_distance + h = sc.constants.h + m_n = sc.constants.m_n + velocity = (h / (wavelength * m_n)).to(unit='m/s') + timeofflight = (sc.midpoints(wavelength.coords['distance'])) / velocity + out = timeofflight.to(unit=time_unit, copy=False) + # Include the variances computed above + out.variances = variance.values + + # Convert coordinates to midpoints + out.coords['toa'] = sc.midpoints(out.coords['toa']) + out.coords['distance'] = sc.midpoints(out.coords['distance']) + + return TimeOfFlightLookupTable(out) + + +def masked_tof_lookup_table( + tof_lookup: TimeOfFlightLookupTable, + variance_threshold: LookupTableVarianceThreshold, +) -> MaskedTimeOfFlightLookupTable: + """ + Mask regions of the lookup table where the variance of the projected time-of-flight + is larger than a given threshold. Parameters ---------- - source_wavelength_range: - Wavelength range of the source pulse. - source_time_range: - Time range of the source pulse. - choppers: - Choppers used to define the frame parameters. - pulse_stride: - Stride of used pulses. Usually 1, but may be a small integer when - pulse-skipping. - pulse_period: - Period of the source pulses, i.e., time between consecutive pulse starts. + tof_lookup: + Lookup table giving time-of-flight as a function of distance and + time-of-arrival. + variance_threshold: + Threshold for the variance of the projected time-of-flight above which regions + are masked. """ - out = [] - for i in range(pulse_stride): - offset = (pulse_period * i).to(unit=source_time_range[0].unit, copy=False) - frames = chopper_cascade.FrameSequence.from_source_pulse( - time_min=source_time_range[0] + offset, - time_max=source_time_range[-1] + offset, - wavelength_min=source_wavelength_range[0], - wavelength_max=source_wavelength_range[-1], - ) - chopped = frames.chop(choppers.values()) - for f in chopped: - for sf in f.subframes: - sf.time -= offset.to(unit=sf.time.unit, copy=False) - out.append(chopped) - return ChopperCascadeFrames(out) + variances = sc.variances(tof_lookup.data) + mask = variances > sc.scalar(variance_threshold, unit=variances.unit) + out = tof_lookup.copy(deep=False) + if mask.any(): + out.masks["uncertain"] = mask + return MaskedTimeOfFlightLookupTable(out) -def frame_at_detector( - frames: ChopperCascadeFrames, - ltotal: Ltotal, - period: FramePeriod, -) -> FrameAtDetector: +def find_fastest_neutron(simulation: SimulationResults) -> FastestNeutron: + """ + Find the fastest neutron in the simulation results. """ - Return the frame at the detector. + ind = np.argmax(simulation.speed.values) + return FastestNeutron( + time_of_arrival=simulation.time_of_arrival[ind], + speed=simulation.speed[ind], + distance=simulation.distance, + ) - This is the result of propagating the source pulse through the chopper cascade to - the detector. The detector may be a monitor or a detector after scattering off the - sample. The frame bounds are then computed from this. - It is assumed that the opening and closing times of the input choppers have been - setup correctly. +def pivot_time_at_detector( + fastest_neutron: FastestNeutron, ltotal: Ltotal +) -> PivotTimeAtDetector: + """ + Compute the pivot time at the detector, i.e., the time of the start of the frame at + the detector. + The assumption here is that the fastest neutron in the simulation results is the one + that arrives at the detector first. + One could have an edge case where a slightly slower neutron which is born earlier + could arrive at the detector first, but this edge case is most probably uncommon, + and the difference in arrival times is likely to be small. Parameters ---------- - frames: - Frames of the chopper cascade. + fastest_neutron: + Properties of the fastest neutron in the simulation results. ltotal: - Total distance between the source and the detector(s). - period: - Period of the frame, i.e., time between the start of two consecutive frames. + Total length of the flight path from the source to the detector. """ - - # In the case of pulse-skipping, only one of the frames should have subframes (the - # others should be empty). - at_detector = [] - for f in frames: - propagated = f[-1].propagate_to(ltotal) - if len(propagated.subframes) > 0: - at_detector.append(propagated) - if len(at_detector) == 0: - raise ValueError("FrameAtDetector: No frames with subframes found.") - if len(at_detector) > 1: - raise ValueError("FrameAtDetector: Multiple frames with subframes found.") - at_detector = at_detector[0] - - # Check that the frame bounds do not span a range larger than the frame period. - # This would indicate that the chopper phases are not set correctly. - bounds = at_detector.bounds()['time'] - diff = (bounds.max('bound') - bounds.min('bound')).flatten(to='x') - if any(diff > period.to(unit=diff.unit, copy=False)): - raise ValueError( - "Frames are overlapping: Computed frame bounds " - f"{bounds} = {diff.max()} are larger than frame period {period}." - ) - return FrameAtDetector(at_detector) + dist = ltotal - fastest_neutron.distance.to(unit=ltotal.unit) + toa = fastest_neutron.time_of_arrival + (dist / fastest_neutron.speed).to( + unit=fastest_neutron.time_of_arrival.unit, copy=False + ) + return PivotTimeAtDetector(toa) def unwrapped_time_of_arrival( - da: RawData, offset: PulseStrideOffset, period: PulsePeriod + da: RawData, offset: PulseStrideOffset, pulse_period: PulsePeriod ) -> UnwrappedTimeOfArrival: """ Compute the unwrapped time of arrival of the neutron at the detector. @@ -268,12 +365,14 @@ def unwrapped_time_of_arrival( Integer offset of the first pulse in the stride (typically zero unless we are using pulse-skipping and the events do not begin with the first pulse in the stride). - period: + pulse_period: Period of the source pulses, i.e., time between consecutive pulse starts. """ if da.bins is None: - # Canonical name in NXmonitor - toa = da.coords['time_of_flight'] + # 'time_of_flight' is the canonical name in NXmonitor, but in some files, it + # may be called 'tof'. + key = next(iter(set(da.coords.keys()) & {'time_of_flight', 'tof'})) + toa = da.coords[key] else: # To unwrap the time of arrival, we want to add the event_time_zero to the # event_time_offset. However, we do not really care about the exact datetimes, @@ -285,25 +384,13 @@ def unwrapped_time_of_arrival( toa = ( coord + time_zero.to(dtype=float, unit=unit, copy=False) - - (offset * period).to(unit=unit, copy=False) + - (offset * pulse_period).to(unit=unit, copy=False) ) return UnwrappedTimeOfArrival(toa) -def frame_at_detector_start_time(frame: FrameAtDetector) -> FrameAtDetectorStartTime: - """ - Compute the start time of the frame at the detector. - - Parameters - ---------- - frame: - Frame at the detector - """ - return FrameAtDetectorStartTime(frame.bounds()['time']['bound', 0]) - - def unwrapped_time_of_arrival_minus_frame_start_time( - toa: UnwrappedTimeOfArrival, start_time: FrameAtDetectorStartTime + toa: UnwrappedTimeOfArrival, pivot_time: PivotTimeAtDetector ) -> UnwrappedTimeOfArrivalMinusStartTime: """ Compute the time of arrival of the neutron at the detector, unwrapped at the pulse @@ -315,12 +402,13 @@ def unwrapped_time_of_arrival_minus_frame_start_time( ---------- toa: Time of arrival of the neutron at the detector, unwrapped at the pulse period. - start_time: - Time of the start of the frame at the detector. + pivot_time: + Pivot time at the detector, i.e., the time of the start of the frame at the + detector. """ # Order of operation to preserve dimension order return UnwrappedTimeOfArrivalMinusStartTime( - -start_time.to(unit=elem_unit(toa), copy=False) + toa + -pivot_time.to(unit=elem_unit(toa), copy=False) + toa ) @@ -346,158 +434,68 @@ def time_of_arrival_minus_start_time_modulo_period( ) -def _approximate_polygon_with_line( - x0: sc.Variable, y0: sc.Variable, dim: str -) -> tuple[sc.Variable, sc.Variable]: - """ - Approximate a polygon defined by the vertices of the subframe with a straight line. - Compute the slope and intercept of the line that minimizes the integrated squared - error over the polygon (i.e. taking the area of the polygon into account, as opposed - to just computing a least-squares fit of the vertices). - The method is described at - https://mathproblems123.wordpress.com/2022/09/13/integrating-polynomials-on-polygons/ - - Parameters - ---------- - x0: - x coordinates of the polygon vertices. - y0: - y coordinates of the polygon vertices. - dim: - Dimension along which the vertices are defined. - """ - iv = x0.dims.index(dim) - x1 = sc.array(dims=x0.dims, values=np.roll(x0.values, 1, axis=iv), unit=x0.unit) - y1 = sc.array(dims=y0.dims, values=np.roll(y0.values, 1, axis=iv), unit=x0.unit) - - x0y1 = x0 * y1 - x1y0 = x1 * y0 - x0y1_x1y0 = x0y1 - x1y0 - - A = ((x0y1_x1y0) / 2).sum(dim) - x = ((x0 + x1) * (x0y1_x1y0) / 6).sum(dim) - y = ((y0 + y1) * (x0y1_x1y0) / 6).sum(dim) - xy = ((x0y1_x1y0) * (2 * x0 * y0 + x0y1 + x1y0 + 2 * x1 * y1) / 24).sum(dim) - xx = ((x0y1_x1y0) * (x0**2 + x0 * x1 + x1**2) / 12).sum(dim) - - a = (xy - x * y / A) / (xx - x**2 / A) - b = (y / A) - a * (x / A) - return a, b - - -def relation_between_time_of_arrival_and_tof( - frame: FrameAtDetector, frame_start: FrameAtDetectorStartTime, ltotal: Ltotal -) -> TimeOfArrivalToTimeOfFlight: - """ - Compute the slope and intercept of a linear relationship between time-of-arrival - and tof, which can be used to create lookup tables which can give the - time-of-flight from the time-of-arrival. - - We take the polygons that define the subframes, given by the chopper cascade, and - approximate them by straight lines. - - Parameters - ---------- - frame: - Frame at the detector. - frame_start: - Time of the start of the frame at the detector. - ltotal: - Total distance between the source and the detector(s). - """ - slopes = [] - intercepts = [] - subframes = sorted(frame.subframes, key=lambda x: x.start_time.min()) - edges = [] - - for sf in subframes: - edges.extend([sf.start_time, sf.end_time]) - a, b = _approximate_polygon_with_line( - x0=sf.time - frame_start, # Horizontal axis is time-of-arrival - y0=( - ltotal * chopper_cascade.wavelength_to_inverse_velocity(sf.wavelength) - ).to(unit=sf.time.unit, copy=False), # Vertical axis is time-of-flight - dim='vertex', - ) - slopes.append(a) - intercepts.append(b) - - # It is sometimes possible that there is time overlap between subframes. - # This is not desired in a chopper cascade but can sometimes happen if the phases - # are not set correctly. Overlap would mean that the start of the next subframe is - # before the end of the previous subframe. - # We sort the edges to make sure that the lookup table is sorted. This creates a - # gap between the overlapping subframes, and discards any neutrons (gives them a - # NaN tof) that fall into the gap, which is the desired behaviour because we - # cannot determine the correct tof for them. - edges = ( - sc.sort(sc.concat(edges, 'subframe').transpose().copy(), 'subframe') - - frame_start - ) - sizes = frame_start.sizes | {'subframe': 2 * len(subframes) - 1} - keys = list(sizes.keys()) - - data = sc.full(sizes=sizes, value=np.nan) - data['subframe', ::2] = sc.concat(slopes, 'subframe').transpose(keys) - da_slope = sc.DataArray(data=data, coords={'subframe': edges}) - - data = sc.full(sizes=sizes, value=np.nan, unit=sf.time.unit) - data['subframe', ::2] = sc.concat(intercepts, 'subframe').transpose(keys) - da_intercept = sc.DataArray(data=data, coords={'subframe': edges}) - - return TimeOfArrivalToTimeOfFlight(slope=da_slope, intercept=da_intercept) - - -def time_of_flight_from_lookup( +def time_of_arrival_folded_by_frame( toa: TimeOfArrivalMinusStartTimeModuloPeriod, - toa_to_tof: TimeOfArrivalToTimeOfFlight, -) -> TofCoord: + pivot_time: PivotTimeAtDetector, +) -> FrameFoldedTimeOfArrival: """ - Compute the time-of-flight from the time-of-arrival. - Lookup tables to convert time-of-arrival to time-of-flight are created internally. + The time of arrival of the neutron at the detector, folded by the frame period. Parameters ---------- toa: Time of arrival of the neutron at the detector, unwrapped at the pulse period, minus the start time of the frame, modulo the frame period. - toa_to_tof: - Conversion from-time-of arrival to time-of-flight. + pivot_time: + Pivot time at the detector, i.e., the time of the start of the frame at the + detector. """ - # Ensure unit consistency - subframe_edges = toa_to_tof.slope.coords['subframe'].to( - unit=elem_unit(toa), copy=False - ) - # Both slope and intercepts should have the same subframe edges - toa_to_tof.slope.coords['subframe'] = subframe_edges - toa_to_tof.intercept.coords['subframe'] = subframe_edges - toa_to_tof.intercept.data = toa_to_tof.intercept.data.to( - unit=elem_unit(toa), copy=False + return FrameFoldedTimeOfArrival( + toa + pivot_time.to(unit=elem_unit(toa), copy=False) ) - slope = sc.lookup(toa_to_tof.slope, dim='subframe')[toa] - intercept = sc.lookup(toa_to_tof.intercept, dim='subframe')[toa] - return TofCoord(slope * toa + intercept) +def time_of_flight_data( + da: RawData, + lookup: MaskedTimeOfFlightLookupTable, + ltotal: Ltotal, + toas: FrameFoldedTimeOfArrival, +) -> TofData: + from scipy.interpolate import RegularGridInterpolator + + lookup_values = lookup.data.to(unit=elem_unit(toas), copy=False).values + # Merge all masks into a single mask + if lookup.masks: + one_mask = reduce(lambda a, b: a | b, lookup.masks.values()).values + # Set masked values to NaN + lookup_values[one_mask] = np.nan + + f = RegularGridInterpolator( + ( + lookup.coords['toa'].to(unit=elem_unit(toas), copy=False).values, + lookup.coords['distance'].to(unit=ltotal.unit, copy=False).values, + ), + lookup_values.T, + method='linear', + bounds_error=False, + ) -def time_of_flight_data(da: RawData, tof: TofCoord) -> TofData: - """ - Add the time-of-flight coordinate to the data. + if da.bins is not None: + ltotal = sc.bins_like(toas, ltotal).bins.constituents['data'] + toas = toas.bins.constituents['data'] + + tofs = sc.array( + dims=toas.dims, values=f((toas.values, ltotal.values)), unit=elem_unit(toas) + ) - Parameters - ---------- - da: - Raw detector data loaded from a NeXus file, e.g., NXdetector containing - NXevent_data. - tof: - Time-of-flight coordinate. - """ out = da.copy(deep=False) - if tof.bins is not None: - out.data = sc.bins(**out.bins.constituents) - out.bins.coords['tof'] = tof + if out.bins is not None: + parts = out.bins.constituents + out.data = sc.bins(**parts) + parts['data'] = tofs + out.bins.coords['tof'] = _bins_no_validate(**parts) else: - out.coords['tof'] = tof + out.coords['tof'] = tofs return TofData(out) @@ -537,23 +535,111 @@ def re_histogram_tof_data(da: TofData) -> ReHistogrammedTofData: return ReHistogrammedTofData(rehist) -def providers() -> tuple[Callable, ...]: +def default_parameters() -> dict: + """ + Default parameters of the time-of-flight workflow. + """ + return { + PulseStride: 1, + PulseStrideOffset: 0, + DistanceResolution: sc.scalar(1.0, unit='cm'), + TimeOfArrivalResolution: 500, + LookupTableVarianceThreshold: 1.0e-2, + SimulationSeed: 1234, + NumberOfNeutrons: 1_000_000, + } + + +def _providers() -> tuple[Callable]: + """ + Base providers of the time-of-flight workflow. + """ return ( - chopper_cascade_frames, - frame_at_detector, + compute_tof_lookup_table, + extract_ltotal, + find_fastest_neutron, frame_period, - unwrapped_time_of_arrival, - frame_at_detector_start_time, - unwrapped_time_of_arrival_minus_frame_start_time, + masked_tof_lookup_table, + pivot_time_at_detector, + pulse_period_from_source, + time_of_arrival_folded_by_frame, time_of_arrival_minus_start_time_modulo_period, - relation_between_time_of_arrival_and_tof, - time_of_flight_from_lookup, time_of_flight_data, + unwrapped_time_of_arrival, + unwrapped_time_of_arrival_minus_frame_start_time, ) -def params() -> dict: - return { - PulseStride: 1, - PulseStrideOffset: 0, - } +def standard_providers() -> tuple[Callable]: + """ + Standard providers of the time-of-flight workflow, using the ``tof`` library to + build the time-of-arrival to time-of-flight lookup table. + """ + from .tof_simulation import run_tof_simulation + + return (*_providers(), run_tof_simulation) + + +class TofWorkflow: + """ + Helper class to build a time-of-flight workflow and cache the expensive part of + the computation: running the simulation and building the lookup table. + """ + + def __init__( + self, + choppers, + facility, + ltotal_range, + pulse_stride=None, + pulse_stride_offset=None, + distance_resolution=None, + toa_resolution=None, + variance_threshold=None, + seed=None, + number_of_neutrons=None, + ): + import sciline as sl + + self.pipeline = sl.Pipeline(standard_providers()) + self.pipeline[Facility] = facility + self.pipeline[Choppers] = choppers + self.pipeline[LtotalRange] = ltotal_range + + params = default_parameters() + self.pipeline[PulseStride] = pulse_stride or params[PulseStride] + self.pipeline[PulseStrideOffset] = ( + pulse_stride_offset or params[PulseStrideOffset] + ) + self.pipeline[DistanceResolution] = ( + distance_resolution or params[DistanceResolution] + ) + self.pipeline[TimeOfArrivalResolution] = ( + toa_resolution or params[TimeOfArrivalResolution] + ) + self.pipeline[LookupTableVarianceThreshold] = ( + variance_threshold or params[LookupTableVarianceThreshold] + ) + self.pipeline[SimulationSeed] = seed or params[SimulationSeed] + self.pipeline[NumberOfNeutrons] = number_of_neutrons or params[NumberOfNeutrons] + + def __getitem__(self, key): + return self.pipeline[key] + + def __setitem__(self, key, value): + self.pipeline[key] = value + + def persist(self) -> None: + for t in (SimulationResults, MaskedTimeOfFlightLookupTable, FastestNeutron): + self.pipeline[t] = self.pipeline.compute(t) + + def compute(self, *args, **kwargs) -> Any: + return self.pipeline.compute(*args, **kwargs) + + def visualize(self, *args, **kwargs) -> Any: + return self.pipeline.visualize(*args, **kwargs) + + def copy(self) -> TofWorkflow: + out = self.__class__(choppers=None, facility=None, ltotal_range=None) + out.pipeline = self.pipeline.copy() + return out diff --git a/tests/tof/unwrap_test.py b/tests/tof/unwrap_test.py index b80a9de71..9b7ac382c 100644 --- a/tests/tof/unwrap_test.py +++ b/tests/tof/unwrap_test.py @@ -13,30 +13,6 @@ sl = pytest.importorskip('sciline') -@pytest.fixture -def ess_10s_14Hz() -> fakes.FakeSource: - return fakes.FakeSource( - frequency=sc.scalar(14.0, unit='Hz'), run_length=sc.scalar(10.0, unit='s') - ) - - -@pytest.fixture -def ess_10s_7Hz() -> fakes.FakeSource: - return fakes.FakeSource( - frequency=sc.scalar(7.0, unit='Hz'), run_length=sc.scalar(10.0, unit='s') - ) - - -@pytest.fixture -def ess_pulse() -> fakes.FakePulse: - return fakes.FakePulse( - time_min=sc.scalar(0.0, unit='ms'), - time_max=sc.scalar(3.0, unit='ms'), - wavelength_min=sc.scalar(0.1, unit='angstrom'), - wavelength_max=sc.scalar(10.0, unit='angstrom'), - ) - - def test_frame_period_is_pulse_period_if_not_pulse_skipping() -> None: pl = sl.Pipeline(unwrap.providers()) period = sc.scalar(123.0, unit='ms') @@ -54,133 +30,87 @@ def test_frame_period_is_multiple_pulse_period_if_pulse_skipping(stride) -> None assert_identical(pl.compute(unwrap.FramePeriod), stride * period) -def test_unwrap_with_no_choppers(ess_10s_14Hz, ess_pulse) -> None: +def test_unwrap_with_no_choppers() -> None: # At this small distance the frames are not overlapping (with the given wavelength # range), despite not using any choppers. distance = sc.scalar(10.0, unit='m') - beamline = fakes.FakeBeamline( - source=ess_10s_14Hz, - pulse=ess_pulse, - choppers={}, # no choppers - monitors={'monitor': distance}, - detectors={}, + + beamline = fakes.FakeBeamlineEss( + choppers={}, + monitors={"detector": distance}, + run_length=sc.scalar(1 / 14, unit="s") * 4, + events_per_pulse=100_000, ) - mon, ref = beamline.get_monitor('monitor') + + mon, ref = beamline.get_monitor('detector') pl = sl.Pipeline(unwrap.providers(), params=unwrap.params()) + pl[unwrap.Facility] = 'ess' pl[unwrap.RawData] = mon - pl[unwrap.PulsePeriod] = beamline._source.pulse_period - pl[unwrap.SourceTimeRange] = ess_pulse.time_min, ess_pulse.time_max - pl[unwrap.SourceWavelengthRange] = ( - ess_pulse.wavelength_min, - ess_pulse.wavelength_max, - ) pl[unwrap.Choppers] = {} pl[unwrap.Ltotal] = distance - result = pl.compute(unwrap.TofData).bins.concat().value + tofs = pl.compute(unwrap.TofData) + + # Convert to wavelength + graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} + wavs = tofs.transform_coords("wavelength", graph=graph).bins.concat().value ref = ref.bins.concat().value - # Ensure that the bounds are close - res_tof = result.coords['tof'] - ref_tof = ref.coords['tof'] - delta = ref_tof.max() - ref_tof.min() - assert sc.abs((res_tof.min() - ref_tof.min()) / delta) < sc.scalar(1e-02) - assert sc.abs((res_tof.max() - ref_tof.max()) / delta) < sc.scalar(1e-02) - - # Because the bounds are not the same, using the same bins for bot results would - # lead to large differences at the edges. So we pick the most narrow range to - # histogram. - bins = sc.linspace( - 'tof', - max(res_tof.min(), ref_tof.min()), - min(res_tof.max(), ref_tof.max()), - num=501, + diff = abs( + (wavs.coords['wavelength'] - ref.coords['wavelength']) + / ref.coords['wavelength'] ) - - ref_hist = ref.hist(tof=bins) - res_hist = result.hist(tof=bins) - diff = ((res_hist - ref_hist) / ref_hist.max()).data - assert sc.abs(diff).max() < sc.scalar(1.0e-1) + # Most errors should be small + assert np.nanpercentile(diff.values, 96) < 1.0 -def test_unwrap_with_frame_overlap_raises(ess_10s_14Hz, ess_pulse) -> None: - distance = sc.scalar(46.0, unit='m') - beamline = fakes.FakeBeamline( - source=ess_10s_14Hz, - pulse=ess_pulse, - choppers={}, # no choppers - monitors={'monitor': distance}, - detectors={}, +# At 80m, event_time_offset does not wrap around (all events are within the same pulse). +# At 85m, event_time_offset wraps around. +@pytest.mark.parametrize('dist', [80.0, 85.0]) +def test_standard_unwrap(dist) -> None: + distance = sc.scalar(dist, unit='m') + beamline = fakes.FakeBeamlineEss( + choppers=fakes.psc_disk_choppers, + monitors={"detector": distance}, + run_length=sc.scalar(1 / 14, unit="s") * 4, + events_per_pulse=100_000, ) - mon, _ = beamline.get_monitor('monitor') + mon, ref = beamline.get_monitor('detector') pl = sl.Pipeline(unwrap.providers(), params=unwrap.params()) + pl[unwrap.Facility] = 'ess' pl[unwrap.RawData] = mon - pl[unwrap.PulsePeriod] = beamline._source.pulse_period - pl[unwrap.SourceTimeRange] = ess_pulse.time_min, ess_pulse.time_max - pl[unwrap.SourceWavelengthRange] = ( - ess_pulse.wavelength_min, - ess_pulse.wavelength_max, - ) - pl[unwrap.Choppers] = {} + pl[unwrap.Choppers] = fakes.psc_disk_choppers pl[unwrap.Ltotal] = distance - with pytest.raises(ValueError, match='Frames are overlapping'): - pl.compute(unwrap.TofData) - -# At 44m, event_time_offset does not wrap around (all events are within the same pulse). -# At 47m, event_time_offset wraps around. -@pytest.mark.parametrize('dist', [44.0, 47.0]) -def test_standard_unwrap(ess_10s_14Hz, ess_pulse, dist) -> None: - distance = sc.scalar(dist, unit='m') - beamline = fakes.FakeBeamline( - source=ess_10s_14Hz, - pulse=ess_pulse, - choppers=fakes.psc_choppers, - monitors={'monitor': distance}, - detectors={}, - time_of_flight_origin='psc1', - ) - mon, ref = beamline.get_monitor('monitor') + tofs = pl.compute(unwrap.TofData) - pl = sl.Pipeline(unwrap.providers(), params=unwrap.params()) - pl[unwrap.RawData] = mon - pl[unwrap.PulsePeriod] = beamline._source.pulse_period - pl[unwrap.SourceTimeRange] = ess_pulse.time_min, ess_pulse.time_max - pl[unwrap.SourceWavelengthRange] = ( - ess_pulse.wavelength_min, - ess_pulse.wavelength_max, - ) - pl[unwrap.Choppers] = fakes.psc_choppers - pl[unwrap.Ltotal] = distance - result = pl.compute(unwrap.TofData) + # Convert to wavelength graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} - ref_wav = ref.transform_coords('wavelength', graph=graph).bins.concat().value - result.coords['Ltotal'] = distance - result_wav = result.transform_coords('wavelength', graph=graph).bins.concat().value - - assert sc.allclose( - result_wav.coords['wavelength'], - ref_wav.coords['wavelength'], - rtol=sc.scalar(1e-02), + wavs = tofs.transform_coords("wavelength", graph=graph).bins.concat().value + ref = ref.bins.concat().value + + diff = abs( + (wavs.coords['wavelength'] - ref.coords['wavelength']) + / ref.coords['wavelength'] ) + # All errors should be small + assert np.nanpercentile(diff.values, 100) < 0.01 -# At 44m, event_time_offset does not wrap around (all events are within the same pulse). -# At 47m, event_time_offset wraps around. -@pytest.mark.parametrize('dist', [44.0, 47.0]) -def test_standard_unwrap_histogram_mode(ess_10s_14Hz, ess_pulse, dist) -> None: +# At 80m, event_time_offset does not wrap around (all events are within the same pulse). +# At 85m, event_time_offset wraps around. +@pytest.mark.parametrize('dist', [80.0, 85.0]) +def test_standard_unwrap_histogram_mode(dist) -> None: distance = sc.scalar(dist, unit='m') - beamline = fakes.FakeBeamline( - source=ess_10s_14Hz, - pulse=ess_pulse, - choppers=fakes.psc_choppers, - monitors={'monitor': distance}, - detectors={}, - time_of_flight_origin='psc1', + beamline = fakes.FakeBeamlineEss( + choppers=fakes.psc_disk_choppers, + monitors={"detector": distance}, + run_length=sc.scalar(1 / 14, unit="s") * 4, + events_per_pulse=100_000, ) - mon, ref = beamline.get_monitor('monitor') + mon, ref = beamline.get_monitor('detector') mon = ( mon.hist( event_time_offset=sc.linspace( @@ -194,214 +124,139 @@ def test_standard_unwrap_histogram_mode(ess_10s_14Hz, ess_pulse, dist) -> None: pl = sl.Pipeline( (*unwrap.providers(), unwrap.re_histogram_tof_data), params=unwrap.params() ) + pl[unwrap.Facility] = 'ess' pl[unwrap.RawData] = mon - pl[unwrap.PulsePeriod] = beamline._source.pulse_period - pl[unwrap.SourceTimeRange] = ess_pulse.time_min, ess_pulse.time_max - pl[unwrap.SourceWavelengthRange] = ( - ess_pulse.wavelength_min, - ess_pulse.wavelength_max, - ) - pl[unwrap.Choppers] = fakes.psc_choppers + pl[unwrap.Choppers] = fakes.psc_disk_choppers pl[unwrap.Ltotal] = distance - result = pl.compute(unwrap.ReHistogrammedTofData) + tofs = pl.compute(unwrap.ReHistogrammedTofData) graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} - result.coords['Ltotal'] = distance - result_wav = result.transform_coords('wavelength', graph=graph) - ref_wav = ( - ref.transform_coords('wavelength', graph=graph) - .bins.concat() - .value.hist(wavelength=result_wav.coords['wavelength']) - ) - diff = (result_wav - ref_wav) / ref_wav - # There are outliers in the diff because the bins don't cover the exact same range, - # and the bins on the edges have high counts in one data array and are empty in the - # other. - # Instead, we check that 96% of the data has an error below 0.1. - x = np.abs(diff.data.values) - assert np.percentile(x[np.isfinite(x)], 96.0) < 0.1 + wavs = tofs.transform_coords('wavelength', graph=graph) + ref = ref.bins.concat().value.hist(wavelength=wavs.coords['wavelength']) + # We divide by the maximum to avoid large relative differences at the edges of the + # frames where the counts are low. + diff = (wavs - ref) / ref.max() + assert np.nanpercentile(diff.values, 96.0) < 0.3 -@pytest.mark.parametrize('dist', [44.0, 47.0]) -def test_pulse_skipping_unwrap(dist) -> None: - distance = sc.scalar(dist, unit='m') - choppers = fakes.psc_choppers.copy() +def test_pulse_skipping_unwrap() -> None: + distance = sc.scalar(100.0, unit='m') + choppers = fakes.psc_disk_choppers.copy() choppers['pulse_skipping'] = fakes.pulse_skipping - # We use the ESS fake here because the fake beamline does not support choppers - # rotating at 7 Hz. beamline = fakes.FakeBeamlineEss( choppers=choppers, - monitors={'monitor': distance}, + monitors={'detector': distance}, run_length=sc.scalar(1.0, unit='s'), events_per_pulse=100_000, ) - mon, ref = beamline.get_monitor('monitor') + mon, ref = beamline.get_monitor('detector') pl = sl.Pipeline(unwrap.providers(), params=unwrap.params()) + pl[unwrap.Facility] = 'ess' pl[unwrap.RawData] = mon - pl[unwrap.PulsePeriod] = 1.0 / beamline.source.frequency + pl[unwrap.Choppers] = choppers + pl[unwrap.Ltotal] = distance pl[unwrap.PulseStride] = 2 - one_pulse = beamline.source.data['pulse', 0] - pl[unwrap.SourceTimeRange] = ( - one_pulse.coords['time'].min(), - one_pulse.coords['time'].max(), - ) - pl[unwrap.SourceWavelengthRange] = ( - one_pulse.coords['wavelength'].min(), - one_pulse.coords['wavelength'].max(), - ) + tofs = pl.compute(unwrap.TofData) - pl[unwrap.Choppers] = choppers - pl[unwrap.Ltotal] = distance - result = pl.compute(unwrap.TofData) + # Convert to wavelength graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} - ref_wav = ref.transform_coords('wavelength', graph=graph).bins.concat().value - result.coords['Ltotal'] = distance - result_wav = result.transform_coords('wavelength', graph=graph).bins.concat().value - - assert sc.allclose( - result_wav.coords['wavelength'], - ref_wav.coords['wavelength'], - rtol=sc.scalar(1e-02), - ) - + wavs = tofs.transform_coords("wavelength", graph=graph).bins.concat().value + ref = ref.bins.concat().value -@pytest.mark.parametrize('dist', [44.0, 47.0]) -def test_pulse_skipping_with_180deg_phase_unwrap(dist) -> None: - from copy import copy + diff = abs( + (wavs.coords['wavelength'] - ref.coords['wavelength']) + / ref.coords['wavelength'] + ) + # All errors should be small + assert np.nanpercentile(diff.values, 100) < 0.01 - distance = sc.scalar(dist, unit='m') - # We will add 180 deg to the phase of the pulse-skipping chopper. This means that - # the first pulse will be blocked and the second one will be transmitted. - # When finding the FrameAtDetector, we need to propagate the second pulse through - # the cascade as well. For that, we need to spin the choppers by an additional - # rotation. - period = 1.0 / sc.scalar(14.0, unit='Hz') - choppers = sc.DataGroup() - for key, value in fakes.psc_choppers.items(): - ch = copy(value) - ch.time_open = sc.concat( - [ch.time_open, ch.time_open + period], ch.time_open.dim - ) - ch.time_close = sc.concat( - [ch.time_close, ch.time_close + period], ch.time_close.dim - ) - choppers[key] = ch - - choppers['pulse_skipping'] = copy(fakes.pulse_skipping) - # Add 180 deg to the phase of the pulse-skipping chopper (same as offsetting the - # time by one period). - choppers['pulse_skipping'].time_open = choppers['pulse_skipping'].time_open + period - choppers['pulse_skipping'].time_close = ( - choppers['pulse_skipping'].time_close + period - ) +def test_pulse_skipping_unwrap_when_all_neutrons_arrive_after_second_pulse() -> None: + distance = sc.scalar(150.0, unit='m') + choppers = fakes.psc_disk_choppers.copy() + choppers['pulse_skipping'] = fakes.pulse_skipping - # We use the ESS fake here because the fake beamline does not support choppers - # rotating at 7 Hz. beamline = fakes.FakeBeamlineEss( choppers=choppers, - monitors={'monitor': distance}, + monitors={'detector': distance}, run_length=sc.scalar(1.0, unit='s'), events_per_pulse=100_000, ) - mon, ref = beamline.get_monitor('monitor') + mon, ref = beamline.get_monitor('detector') pl = sl.Pipeline(unwrap.providers(), params=unwrap.params()) + pl[unwrap.Facility] = 'ess' pl[unwrap.RawData] = mon - pl[unwrap.PulsePeriod] = 1.0 / beamline.source.frequency + pl[unwrap.Choppers] = choppers + pl[unwrap.Ltotal] = distance pl[unwrap.PulseStride] = 2 + pl[unwrap.PulseStrideOffset] = 1 # Start the stride at the second pulse - one_pulse = beamline.source.data['pulse', 0] - pl[unwrap.SourceTimeRange] = ( - one_pulse.coords['time'].min(), - one_pulse.coords['time'].max(), - ) - pl[unwrap.SourceWavelengthRange] = ( - one_pulse.coords['wavelength'].min(), - one_pulse.coords['wavelength'].max(), - ) + tofs = pl.compute(unwrap.TofData) - pl[unwrap.Choppers] = choppers - pl[unwrap.Ltotal] = distance - result = pl.compute(unwrap.TofData) + # Convert to wavelength graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} - ref_wav = ref.transform_coords('wavelength', graph=graph).bins.concat().value - result.coords['Ltotal'] = distance - result_wav = result.transform_coords('wavelength', graph=graph).bins.concat().value - - assert sc.allclose( - result_wav.coords['wavelength'], - ref_wav.coords['wavelength'], - rtol=sc.scalar(1e-02), + wavs = tofs.transform_coords("wavelength", graph=graph).bins.concat().value + ref = ref.bins.concat().value + + diff = abs( + (wavs.coords['wavelength'] - ref.coords['wavelength']) + / ref.coords['wavelength'] ) + # All errors should be small + assert np.nanpercentile(diff.values, 100) < 0.01 -def test_pulse_skipping_unwrap_with_half_of_first_frame_missing() -> None: - distance = sc.scalar(50.0, unit='m') - choppers = fakes.psc_choppers.copy() +def test_pulse_skipping_unwrap_when_first_half_of_first_pulse_is_missing() -> None: + distance = sc.scalar(100.0, unit='m') + choppers = fakes.psc_disk_choppers.copy() choppers['pulse_skipping'] = fakes.pulse_skipping - # We use the ESS fake here because the fake beamline does not support choppers - # rotating at 7 Hz. beamline = fakes.FakeBeamlineEss( choppers=choppers, - monitors={'monitor': distance}, + monitors={'detector': distance}, run_length=sc.scalar(1.0, unit='s'), events_per_pulse=100_000, ) - mon, ref = beamline.get_monitor('monitor') + mon, ref = beamline.get_monitor('detector') pl = sl.Pipeline(unwrap.providers(), params=unwrap.params()) + pl[unwrap.Facility] = 'ess' pl[unwrap.RawData] = mon[1:].copy() # Skip first pulse = half of the first frame - pl[unwrap.PulsePeriod] = 1.0 / beamline.source.frequency + pl[unwrap.Choppers] = choppers + pl[unwrap.Ltotal] = distance pl[unwrap.PulseStride] = 2 pl[unwrap.PulseStrideOffset] = 1 # Start the stride at the second pulse - one_pulse = beamline.source.data['pulse', 0] - pl[unwrap.SourceTimeRange] = ( - one_pulse.coords['time'].min(), - one_pulse.coords['time'].max(), - ) - pl[unwrap.SourceWavelengthRange] = ( - one_pulse.coords['wavelength'].min(), - one_pulse.coords['wavelength'].max(), - ) + tofs = pl.compute(unwrap.TofData) - pl[unwrap.Choppers] = choppers - pl[unwrap.Ltotal] = distance - result = pl.compute(unwrap.TofData) + # Convert to wavelength graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} - ref_wav = ( - ref[1:].copy().transform_coords('wavelength', graph=graph).bins.concat().value - ) - result.coords['Ltotal'] = distance - result_wav = result.transform_coords('wavelength', graph=graph).bins.concat().value + wavs = tofs.transform_coords("wavelength", graph=graph).bins.concat().value + ref = ref[1:].copy().bins.concat().value - assert sc.allclose( - result_wav.coords['wavelength'], - ref_wav.coords['wavelength'], - rtol=sc.scalar(1e-02), + diff = abs( + (wavs.coords['wavelength'] - ref.coords['wavelength']) + / ref.coords['wavelength'] ) + # All errors should be small + assert np.nanpercentile(diff.values, 100) < 0.01 -@pytest.mark.parametrize('dist', [44.0, 47.0]) -def test_pulse_skipping_unwrap_histogram_mode(dist) -> None: - distance = sc.scalar(dist, unit='m') - choppers = fakes.psc_choppers.copy() +def test_pulse_skipping_unwrap_histogram_mode() -> None: + distance = sc.scalar(100.0, unit='m') + choppers = fakes.psc_disk_choppers.copy() choppers['pulse_skipping'] = fakes.pulse_skipping - # We use the ESS fake here because the fake beamline does not support choppers - # rotating at 7 Hz. beamline = fakes.FakeBeamlineEss( choppers=choppers, - monitors={'monitor': distance}, + monitors={'detector': distance}, run_length=sc.scalar(1.0, unit='s'), events_per_pulse=100_000, ) - mon, ref = beamline.get_monitor('monitor') - + mon, ref = beamline.get_monitor('detector') mon = ( mon.hist( event_time_offset=sc.linspace( @@ -415,36 +270,16 @@ def test_pulse_skipping_unwrap_histogram_mode(dist) -> None: pl = sl.Pipeline( (*unwrap.providers(), unwrap.re_histogram_tof_data), params=unwrap.params() ) + pl[unwrap.Facility] = 'ess' pl[unwrap.RawData] = mon - pl[unwrap.PulsePeriod] = 1.0 / beamline.source.frequency - pl[unwrap.PulseStride] = 2 - - one_pulse = beamline.source.data['pulse', 0] - pl[unwrap.SourceTimeRange] = ( - one_pulse.coords['time'].min(), - one_pulse.coords['time'].max(), - ) - pl[unwrap.SourceWavelengthRange] = ( - one_pulse.coords['wavelength'].min(), - one_pulse.coords['wavelength'].max(), - ) - - pl[unwrap.Choppers] = choppers + pl[unwrap.Choppers] = fakes.psc_disk_choppers pl[unwrap.Ltotal] = distance - - result = pl.compute(unwrap.ReHistogrammedTofData) + pl[unwrap.PulseStride] = 2 + tofs = pl.compute(unwrap.ReHistogrammedTofData) graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} - result.coords['Ltotal'] = distance - result_wav = result.transform_coords('wavelength', graph=graph) - ref_wav = ( - ref.transform_coords('wavelength', graph=graph) - .bins.concat() - .value.hist(wavelength=result_wav.coords['wavelength']) - ) - - # In this case, we used the ESS pulse. The counts on the edges of the frame are low, - # so relative differences can be large. Instead of a plain relative difference, we - # use the maximum of the reference data as the denominator. - diff = (result_wav - ref_wav) / ref_wav.max() - # Note: very conservative threshold to avoid making the test flaky. - assert sc.abs(diff).data.max() < sc.scalar(0.5) + wavs = tofs.transform_coords('wavelength', graph=graph) + ref = ref.bins.concat().value.hist(wavelength=wavs.coords['wavelength']) + # We divide by the maximum to avoid large relative differences at the edges of the + # frames where the counts are low. + diff = (wavs - ref) / ref.max() + assert np.nanpercentile(diff.values, 96.0) < 0.3 diff --git a/tests/tof/wfm_dream_test.py b/tests/tof/wfm_test.py similarity index 69% rename from tests/tof/wfm_dream_test.py rename to tests/tof/wfm_test.py index 0aa7932fe..1a6719ab4 100644 --- a/tests/tof/wfm_dream_test.py +++ b/tests/tof/wfm_test.py @@ -10,13 +10,13 @@ from scippneutron.chopper import DiskChopper from scippneutron.conversion.graph.beamline import beamline from scippneutron.conversion.graph.tof import elastic -from scippneutron.tof import chopper_cascade, fakes, unwrap +from scippneutron.tof import fakes, unwrap sl = pytest.importorskip('sciline') @pytest.fixture -def disk_choppers(): +def dream_disk_choppers(): psc1 = DiskChopper( frequency=sc.scalar(14.0, unit="Hz"), beam_position=sc.scalar(0.0, unit="deg"), @@ -92,15 +92,15 @@ def disk_choppers(): @pytest.fixture -def overlap_choppers(disk_choppers): - out = disk_choppers.copy() +def dream_choppers_with_frame_overlap(dream_disk_choppers): + out = dream_disk_choppers.copy() out['bcc'] = DiskChopper( frequency=sc.scalar(112.0, unit="Hz"), beam_position=sc.scalar(0.0, unit="deg"), phase=sc.scalar(240 - 180, unit="deg"), axle_position=sc.vector(value=[0, 0, 9.78], unit="m"), slit_begin=sc.array(dims=["cutout"], values=[-36.875, 143.125], unit="deg"), - slit_end=sc.array(dims=["cutout"], values=[46.875, 216.875], unit="deg"), + slit_end=sc.array(dims=["cutout"], values=[56.875, 216.875], unit="deg"), slit_height=sc.scalar(10.0, unit="cm"), radius=sc.scalar(30.0, unit="cm"), ) @@ -122,14 +122,9 @@ def overlap_choppers(disk_choppers): ) @pytest.mark.parametrize("time_offset_unit", ['s', 'ms', 'us', 'ns']) @pytest.mark.parametrize("distance_unit", ['m', 'mm']) -def test_dream_wfm(disk_choppers, npulses, ltotal, time_offset_unit, distance_unit): - choppers = { - key: chopper_cascade.Chopper.from_disk_chopper( - chop, pulse_frequency=sc.scalar(14.0, unit="Hz"), npulses=npulses - ) - for key, chop in disk_choppers.items() - } - +def test_dream_wfm( + dream_disk_choppers, npulses, ltotal, time_offset_unit, distance_unit +): monitors = { f"detector{i}": ltot for i, ltot in enumerate(ltotal.flatten(to='detector')) } @@ -140,7 +135,7 @@ def test_dream_wfm(disk_choppers, npulses, ltotal, time_offset_unit, distance_un ) birth_times = sc.full(sizes=wavelengths.sizes, value=1.5, unit='ms') ess_beamline = fakes.FakeBeamlineEss( - choppers=choppers, + choppers=dream_disk_choppers, monitors=monitors, run_length=sc.scalar(1 / 14, unit="s") * npulses, events_per_pulse=len(wavelengths), @@ -180,23 +175,11 @@ def test_dream_wfm(disk_choppers, npulses, ltotal, time_offset_unit, distance_un # Set up the workflow workflow = sl.Pipeline(unwrap.providers(), params=unwrap.params()) - workflow[unwrap.PulsePeriod] = sc.reciprocal(ess_beamline.source.frequency) - - # Define the extent of the pulse that contains the 6 neutrons in time and wavelength - # Note that we make a larger encompassing pulse to ensure that the frame bounds are - # computed correctly - workflow[unwrap.SourceTimeRange] = ( - sc.scalar(0.0, unit='ms'), - sc.scalar(4.9, unit='ms'), - ) - workflow[unwrap.SourceWavelengthRange] = ( - sc.scalar(0.2, unit='angstrom'), - sc.scalar(16.0, unit='angstrom'), - ) - - workflow[unwrap.Choppers] = choppers - workflow[unwrap.Ltotal] = raw_data.coords['Ltotal'] + workflow[unwrap.Facility] = 'ess' workflow[unwrap.RawData] = raw_data + workflow[unwrap.Choppers] = dream_disk_choppers + workflow[unwrap.Ltotal] = raw_data.coords['Ltotal'] + workflow[unwrap.NumberOfNeutrons] = 100_000 # Compute time-of-flight tofs = workflow.compute(unwrap.TofData) @@ -234,15 +217,8 @@ def test_dream_wfm(disk_choppers, npulses, ltotal, time_offset_unit, distance_un @pytest.mark.parametrize("time_offset_unit", ['s', 'ms', 'us', 'ns']) @pytest.mark.parametrize("distance_unit", ['m', 'mm']) def test_dream_wfm_with_subframe_time_overlap( - overlap_choppers, npulses, ltotal, time_offset_unit, distance_unit + dream_choppers_with_frame_overlap, npulses, ltotal, time_offset_unit, distance_unit ): - choppers = { - key: chopper_cascade.Chopper.from_disk_chopper( - chop, pulse_frequency=sc.scalar(14.0, unit="Hz"), npulses=npulses - ) - for key, chop in overlap_choppers.items() - } - monitors = { f"detector{i}": ltot for i, ltot in enumerate(ltotal.flatten(to='detector')) } @@ -252,14 +228,14 @@ def test_dream_wfm_with_subframe_time_overlap( birth_times = [1.5] * len(wavelengths) # Add overlap neutrons - birth_times.extend([0.0, 3.1]) - wavelengths.extend([2.7, 2.5]) + birth_times.extend([0.0, 3.3]) + wavelengths.extend([2.6, 2.4]) wavelengths = sc.array(dims=['event'], values=wavelengths, unit='angstrom') birth_times = sc.array(dims=['event'], values=birth_times, unit='ms') ess_beamline = fakes.FakeBeamlineEss( - choppers=choppers, + choppers=dream_choppers_with_frame_overlap, monitors=monitors, run_length=sc.scalar(1 / 14, unit="s") * npulses, events_per_pulse=len(wavelengths), @@ -299,23 +275,15 @@ def test_dream_wfm_with_subframe_time_overlap( # Set up the workflow workflow = sl.Pipeline(unwrap.providers(), params=unwrap.params()) - workflow[unwrap.PulsePeriod] = sc.reciprocal(ess_beamline.source.frequency) - - # Define the extent of the pulse that contains the 6 neutrons in time and wavelength - # Note that we make a larger encompassing pulse to ensure that the frame bounds are - # computed correctly - workflow[unwrap.SourceTimeRange] = ( - sc.scalar(0.0, unit='ms'), - sc.scalar(4.9, unit='ms'), - ) - workflow[unwrap.SourceWavelengthRange] = ( - sc.scalar(0.2, unit='angstrom'), - sc.scalar(16.0, unit='angstrom'), - ) - - workflow[unwrap.Choppers] = choppers - workflow[unwrap.Ltotal] = raw_data.coords['Ltotal'] + workflow[unwrap.Facility] = 'ess' workflow[unwrap.RawData] = raw_data + workflow[unwrap.Choppers] = dream_choppers_with_frame_overlap + workflow[unwrap.Ltotal] = raw_data.coords['Ltotal'] + workflow[unwrap.LookupTableVarianceThreshold] = 1.0e-3 + workflow[unwrap.NumberOfNeutrons] = 100_000 + + # Make sure the lookup table has a mask + assert len(workflow.compute(unwrap.MaskedTimeOfFlightLookupTable).masks) > 0 # Compute time-of-flight tofs = workflow.compute(unwrap.TofData) @@ -327,15 +295,9 @@ def test_dream_wfm_with_subframe_time_overlap( # Compare the computed wavelengths to the true wavelengths for i in range(npulses): - result_wav = wav_wfm['pulse', i].flatten(to='detector') - result_tof = tofs['pulse', i].flatten(to='detector') - for j in range(len(result_wav)): - computed_tofs = result_tof[j].values.coords["tof"] - # The two neutrons in the overlap region should have NaN tofs - assert sc.isnan(computed_tofs[-2]) - assert sc.isnan(computed_tofs[-1]) - - computed_wavelengths = result_wav[j].values.coords["wavelength"] + result = wav_wfm['pulse', i].flatten(to='detector') + for j in range(len(result)): + computed_wavelengths = result[j].values.coords["wavelength"] assert sc.allclose( computed_wavelengths[:-2], true_wavelengths['pulse', i][:-2], @@ -344,3 +306,95 @@ def test_dream_wfm_with_subframe_time_overlap( # The two neutrons in the overlap region should have NaN wavelengths assert sc.isnan(computed_wavelengths[-2]) assert sc.isnan(computed_wavelengths[-1]) + + +@pytest.mark.parametrize("npulses", [1, 2]) +@pytest.mark.parametrize( + "ltotal", + [ + sc.array(dims=['detector_number'], values=[26.0], unit='m'), + sc.array(dims=['detector_number'], values=[26.0, 25.5], unit='m'), + sc.array( + dims=['y', 'x'], values=[[26.0, 25.1, 26.33], [25.9, 26.0, 25.7]], unit='m' + ), + ], +) +@pytest.mark.parametrize("time_offset_unit", ['s', 'ms', 'us', 'ns']) +@pytest.mark.parametrize("distance_unit", ['m', 'mm']) +def test_v20_compute_wavelengths_from_wfm( + npulses, ltotal, time_offset_unit, distance_unit +): + monitors = { + f"detector{i}": ltot for i, ltot in enumerate(ltotal.flatten(to='detector')) + } + + # Create some neutron events + wavelengths = sc.array( + dims=['event'], values=[2.75, 4.2, 5.4, 6.5, 7.6, 8.75], unit='angstrom' + ) + birth_times = sc.full(sizes=wavelengths.sizes, value=1.5, unit='ms') + ess_beamline = fakes.FakeBeamlineEss( + choppers=fakes.wfm_disk_choppers, + monitors=monitors, + run_length=sc.scalar(1 / 14, unit="s") * npulses, + events_per_pulse=len(wavelengths), + source=partial( + tof_pkg.Source.from_neutrons, + birth_times=birth_times, + wavelengths=wavelengths, + frequency=sc.scalar(14.0, unit="Hz"), + ), + ) + + # Save the true wavelengths for later + true_wavelengths = ess_beamline.source.data.coords["wavelength"] + + raw_data = sc.concat( + [ess_beamline.get_monitor(key)[0] for key in monitors.keys()], + dim='detector', + ).fold(dim='detector', sizes=ltotal.sizes) + + # Convert the time offset to the unit requested by the test + raw_data.bins.coords["event_time_offset"] = raw_data.bins.coords[ + "event_time_offset" + ].to(unit=time_offset_unit, copy=False) + + raw_data.coords['Ltotal'] = ltotal.to(unit=distance_unit, copy=False) + + # Verify that all 6 neutrons made it through the chopper cascade + assert sc.identical( + raw_data.bins.concat('pulse').hist().data, + sc.array( + dims=['detector'], + values=[len(wavelengths) * npulses] * len(monitors), + unit="counts", + dtype='float64', + ).fold(dim='detector', sizes=ltotal.sizes), + ) + + # Set up the workflow + workflow = sl.Pipeline(unwrap.providers(), params=unwrap.params()) + workflow[unwrap.Facility] = 'ess' + workflow[unwrap.RawData] = raw_data + workflow[unwrap.Choppers] = fakes.wfm_disk_choppers + workflow[unwrap.Ltotal] = raw_data.coords['Ltotal'] + workflow[unwrap.NumberOfNeutrons] = 100_000 + + # Compute time-of-flight + tofs = workflow.compute(unwrap.TofData) + assert {dim: tofs.sizes[dim] for dim in ltotal.sizes} == ltotal.sizes + + # Convert to wavelength + graph = {**beamline(scatter=False), **elastic("tof")} + wav_wfm = tofs.transform_coords("wavelength", graph=graph) + + # Compare the computed wavelengths to the true wavelengths + for i in range(npulses): + result = wav_wfm['pulse', i].flatten(to='detector') + for j in range(len(result)): + computed_wavelengths = result[j].values.coords["wavelength"] + assert sc.allclose( + computed_wavelengths, + true_wavelengths['pulse', i], + rtol=sc.scalar(1e-02), + ) diff --git a/tests/tof/wfm_v20_test.py b/tests/tof/wfm_v20_test.py deleted file mode 100644 index ac2cf7cc8..000000000 --- a/tests/tof/wfm_v20_test.py +++ /dev/null @@ -1,208 +0,0 @@ -# SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2024 Scipp contributors (https://github.com/scipp) - -from functools import partial - -import numpy as np -import pytest -import scipp as sc -import tof as tof_pkg - -from scippneutron.chopper import DiskChopper -from scippneutron.conversion.graph.beamline import beamline -from scippneutron.conversion.graph.tof import elastic -from scippneutron.tof import chopper_cascade, fakes, unwrap - -sl = pytest.importorskip('sciline') - - -@pytest.fixture -def disk_choppers(): - wfm1 = DiskChopper( - frequency=sc.scalar(-70.0, unit="Hz"), - beam_position=sc.scalar(0.0, unit="deg"), - phase=sc.scalar(-47.10, unit="deg"), - axle_position=sc.vector(value=[0, 0, 6.6], unit="m"), - slit_begin=sc.array( - dims=["cutout"], - values=np.array([83.71, 140.49, 193.26, 242.32, 287.91, 330.3]) + 15.0, - unit="deg", - ), - slit_end=sc.array( - dims=["cutout"], - values=np.array([94.7, 155.79, 212.56, 265.33, 314.37, 360.0]) + 15.0, - unit="deg", - ), - slit_height=sc.scalar(10.0, unit="cm"), - radius=sc.scalar(30.0, unit="cm"), - ) - - wfm2 = DiskChopper( - frequency=sc.scalar(-70.0, unit="Hz"), - beam_position=sc.scalar(0.0, unit="deg"), - phase=sc.scalar(-76.76, unit="deg"), - axle_position=sc.vector(value=[0, 0, 7.1], unit="m"), - slit_begin=sc.array( - dims=["cutout"], - values=np.array([65.04, 126.1, 182.88, 235.67, 284.73, 330.32]) + 15.0, - unit="deg", - ), - slit_end=sc.array( - dims=["cutout"], - values=np.array([76.03, 141.4, 202.18, 254.97, 307.74, 360.0]) + 15.0, - unit="deg", - ), - slit_height=sc.scalar(10.0, unit="cm"), - radius=sc.scalar(30.0, unit="cm"), - ) - - foc1 = DiskChopper( - frequency=sc.scalar(-56.0, unit="Hz"), - beam_position=sc.scalar(0.0, unit="deg"), - phase=sc.scalar(-62.40, unit="deg"), - axle_position=sc.vector(value=[0, 0, 8.8], unit="m"), - slit_begin=sc.array( - dims=["cutout"], - values=np.array([74.6, 139.6, 194.3, 245.3, 294.8, 347.2]), - unit="deg", - ), - slit_end=sc.array( - dims=["cutout"], - values=np.array([95.2, 162.8, 216.1, 263.1, 310.5, 371.6]), - unit="deg", - ), - slit_height=sc.scalar(10.0, unit="cm"), - radius=sc.scalar(30.0, unit="cm"), - ) - - foc2 = DiskChopper( - frequency=sc.scalar(-28.0, unit="Hz"), - beam_position=sc.scalar(0.0, unit="deg"), - phase=sc.scalar(-12.27, unit="deg"), - axle_position=sc.vector(value=[0, 0, 15.9], unit="m"), - slit_begin=sc.array( - dims=["cutout"], - values=np.array([98.0, 154.0, 206.8, 255.0, 299.0, 344.65]), - unit="deg", - ), - slit_end=sc.array( - dims=["cutout"], - values=np.array([134.6, 190.06, 237.01, 280.88, 323.56, 373.76]), - unit="deg", - ), - slit_height=sc.scalar(10.0, unit="cm"), - radius=sc.scalar(30.0, unit="cm"), - ) - - return {"wfm1": wfm1, "wfm2": wfm2, "foc1": foc1, "foc2": foc2} - - -@pytest.mark.parametrize("npulses", [1, 2]) -@pytest.mark.parametrize( - "ltotal", - [ - sc.array(dims=['detector_number'], values=[26.0], unit='m'), - sc.array(dims=['detector_number'], values=[26.0, 25.5], unit='m'), - sc.array( - dims=['y', 'x'], values=[[26.0, 25.1, 26.33], [25.9, 26.0, 25.7]], unit='m' - ), - ], -) -@pytest.mark.parametrize("time_offset_unit", ['s', 'ms', 'us', 'ns']) -@pytest.mark.parametrize("distance_unit", ['m', 'mm']) -def test_v20_compute_wavelengths_from_wfm( - disk_choppers, npulses, ltotal, time_offset_unit, distance_unit -): - choppers = { - key: chopper_cascade.Chopper.from_disk_chopper( - chop, pulse_frequency=sc.scalar(14.0, unit="Hz"), npulses=npulses - ) - for key, chop in disk_choppers.items() - } - - monitors = { - f"detector{i}": ltot for i, ltot in enumerate(ltotal.flatten(to='detector')) - } - - # Create some neutron events - wavelengths = sc.array( - dims=['event'], values=[2.75, 4.2, 5.4, 6.5, 7.6, 8.75], unit='angstrom' - ) - birth_times = sc.full(sizes=wavelengths.sizes, value=1.5, unit='ms') - ess_beamline = fakes.FakeBeamlineEss( - choppers=choppers, - monitors=monitors, - run_length=sc.scalar(1 / 14, unit="s") * npulses, - events_per_pulse=len(wavelengths), - source=partial( - tof_pkg.Source.from_neutrons, - birth_times=birth_times, - wavelengths=wavelengths, - frequency=sc.scalar(14.0, unit="Hz"), - ), - ) - - # Save the true wavelengths for later - true_wavelengths = ess_beamline.source.data.coords["wavelength"] - - raw_data = sc.concat( - [ess_beamline.get_monitor(key)[0] for key in monitors.keys()], - dim='detector', - ).fold(dim='detector', sizes=ltotal.sizes) - - # Convert the time offset to the unit requested by the test - raw_data.bins.coords["event_time_offset"] = raw_data.bins.coords[ - "event_time_offset" - ].to(unit=time_offset_unit, copy=False) - - raw_data.coords['Ltotal'] = ltotal.to(unit=distance_unit, copy=False) - - # Verify that all 6 neutrons made it through the chopper cascade - assert sc.identical( - raw_data.bins.concat('pulse').hist().data, - sc.array( - dims=['detector'], - values=[len(wavelengths) * npulses] * len(monitors), - unit="counts", - dtype='float64', - ).fold(dim='detector', sizes=ltotal.sizes), - ) - - # Set up the workflow - workflow = sl.Pipeline(unwrap.providers(), params=unwrap.params()) - workflow[unwrap.PulsePeriod] = sc.reciprocal(ess_beamline.source.frequency) - - # Define the extent of the pulse that contains the 6 neutrons in time and wavelength - # Note that we make a larger encompassing pulse to ensure that the frame bounds are - # computed correctly - workflow[unwrap.SourceTimeRange] = ( - sc.scalar(0.0, unit='ms'), - sc.scalar(3.4, unit='ms'), - ) - workflow[unwrap.SourceWavelengthRange] = ( - sc.scalar(0.2, unit='angstrom'), - sc.scalar(10.0, unit='angstrom'), - ) - - workflow[unwrap.Choppers] = choppers - workflow[unwrap.Ltotal] = raw_data.coords['Ltotal'] - workflow[unwrap.RawData] = raw_data - - # Compute time-of-flight - tofs = workflow.compute(unwrap.TofData) - assert {dim: tofs.sizes[dim] for dim in ltotal.sizes} == ltotal.sizes - - # Convert to wavelength - graph = {**beamline(scatter=False), **elastic("tof")} - wav_wfm = tofs.transform_coords("wavelength", graph=graph) - - # Compare the computed wavelengths to the true wavelengths - for i in range(npulses): - result = wav_wfm['pulse', i].flatten(to='detector') - for j in range(len(result)): - computed_wavelengths = result[j].values.coords["wavelength"] - assert sc.allclose( - computed_wavelengths, - true_wavelengths['pulse', i], - rtol=sc.scalar(1e-02), - )