diff --git a/docs/user-guide/tof/dream.ipynb b/docs/user-guide/tof/dream.ipynb index 6209e957..9bcd8dda 100644 --- a/docs/user-guide/tof/dream.ipynb +++ b/docs/user-guide/tof/dream.ipynb @@ -239,7 +239,7 @@ "raw_data = ess_beamline.get_monitor(\"detector\")[0]\n", "\n", "# Visualize\n", - "raw_data.hist(event_time_offset=300).sum(\"pulse\").plot()" + "raw_data.hist(event_time_offset=300).squeeze().plot()" ] }, { @@ -290,8 +290,10 @@ " time_of_flight.providers(), params=time_of_flight.default_parameters()\n", ")\n", "workflow[time_of_flight.RawData] = raw_data\n", - "workflow[time_of_flight.LtotalRange] = (sc.scalar(75.5, unit='m'),\n", - " sc.scalar(78.0, unit='m'))\n", + "workflow[time_of_flight.LtotalRange] = (\n", + " sc.scalar(75.5, unit=\"m\"),\n", + " sc.scalar(78.0, unit=\"m\"),\n", + ")\n", "\n", "workflow.visualize(time_of_flight.TofData)" ] @@ -314,8 +316,7 @@ "outputs": [], "source": [ "workflow[time_of_flight.SimulationResults] = time_of_flight.simulate_beamline(\n", - " choppers=disk_choppers,\n", - " neutrons=2_000_000\n", + " choppers=disk_choppers, neutrons=2_000_000\n", ")" ] }, @@ -343,17 +344,24 @@ "outputs": [], "source": [ "sim = workflow.compute(time_of_flight.SimulationResults)\n", - "# Compute time-of-arrival at the detector\n", - "tarrival = sim.time_of_arrival + ((Ltotal - sim.distance) / sim.speed).to(unit=\"us\")\n", - "# Compute time-of-flight at the detector\n", - "tflight = (Ltotal / sim.speed).to(unit=\"us\")\n", - "\n", - "events = sc.DataArray(\n", - " data=sim.weight,\n", - " coords={\"wavelength\": sim.wavelength, \"toa\": tarrival, \"tof\": tflight},\n", - ")\n", - "fig1 = events.hist(wavelength=300, toa=300).plot(norm=\"log\")\n", - "fig2 = events.hist(tof=300, toa=300).plot(norm=\"log\")\n", + "\n", + "\n", + "def to_event_time_offset(sim):\n", + " # Compute event_time_offset at the detector\n", + " eto = (\n", + " sim.time_of_arrival + ((Ltotal - sim.distance) / sim.speed).to(unit=\"us\")\n", + " ) % sc.scalar(1e6 / 14.0, 
unit=\"us\")\n", + " # Compute time-of-flight at the detector\n", + " tof = (Ltotal / sim.speed).to(unit=\"us\")\n", + " return sc.DataArray(\n", + " data=sim.weight,\n", + " coords={\"wavelength\": sim.wavelength, \"event_time_offset\": eto, \"tof\": tof},\n", + " )\n", + "\n", + "\n", + "events = to_event_time_offset(sim)\n", + "fig1 = events.hist(wavelength=300, event_time_offset=300).plot(norm=\"log\")\n", + "fig2 = events.hist(tof=300, event_time_offset=300).plot(norm=\"log\")\n", "fig1 + fig2" ] }, @@ -374,10 +382,10 @@ "metadata": {}, "outputs": [], "source": [ - "table = workflow.compute(time_of_flight.TimeOfFlightLookupTable)\n", + "table = workflow.compute(time_of_flight.TimeOfFlightLookupTable).squeeze()\n", "\n", "# Overlay mean on the figure above\n", - "table[\"distance\", 13].plot(ax=fig2.ax, color=\"C1\", ls='-', marker=None)" + "table[\"distance\", 13].plot(ax=fig2.ax, color=\"C1\", ls=\"-\", marker=None)" ] }, { @@ -447,7 +455,8 @@ "# Define wavelength bin edges\n", "wavs = sc.linspace(\"wavelength\", 0.8, 4.6, 201, unit=\"angstrom\")\n", "\n", - "wav_wfm.hist(wavelength=wavs).sum(\"pulse\").plot()" + "histogrammed = wav_wfm.hist(wavelength=wavs).squeeze()\n", + "histogrammed.plot()" ] }, { @@ -473,7 +482,7 @@ "\n", "pp.plot(\n", " {\n", - " \"wfm\": wav_wfm.hist(wavelength=wavs).sum(\"pulse\"),\n", + " \"wfm\": histogrammed,\n", " \"ground_truth\": ground_truth.hist(wavelength=wavs),\n", " }\n", ")" @@ -530,16 +539,12 @@ "outputs": [], "source": [ "raw_data = sc.concat(\n", - " [ess_beamline.get_monitor(key)[0] for key in monitors.keys()],\n", + " [ess_beamline.get_monitor(key)[0].squeeze() for key in monitors.keys()],\n", " dim=\"detector_number\",\n", ")\n", "\n", "# Visualize\n", - "pp.plot(\n", - " sc.collapse(\n", - " raw_data.hist(event_time_offset=300).sum(\"pulse\"), keep=\"event_time_offset\"\n", - " )\n", - ")" + "pp.plot(sc.collapse(raw_data.hist(event_time_offset=300), keep=\"event_time_offset\"))" ] }, { @@ -655,22 +660,14 @@ 
"source": [ "# Update workflow\n", "workflow[time_of_flight.SimulationResults] = time_of_flight.simulate_beamline(\n", - " choppers=disk_choppers,\n", - " neutrons=2_000_000\n", + " choppers=disk_choppers, neutrons=2_000_000\n", ")\n", "workflow[time_of_flight.RawData] = ess_beamline.get_monitor(\"detector\")[0]\n", "\n", "sim = workflow.compute(time_of_flight.SimulationResults)\n", - "# Compute time-of-arrival at the detector\n", - "tarrival = sim.time_of_arrival + ((Ltotal - sim.distance) / sim.speed).to(unit=\"us\")\n", - "# Compute time-of-flight at the detector\n", - "tflight = (Ltotal / sim.speed).to(unit=\"us\")\n", - "\n", - "events = sc.DataArray(\n", - " data=sim.weight,\n", - " coords={\"wavelength\": sim.wavelength, \"toa\": tarrival, \"tof\": tflight},\n", - ")\n", - "events.hist(wavelength=300, toa=300).plot(norm=\"log\")" + "\n", + "events = to_event_time_offset(sim)\n", + "events.hist(wavelength=300, event_time_offset=300).plot(norm=\"log\")" ] }, { @@ -696,7 +693,7 @@ "metadata": {}, "outputs": [], "source": [ - "table = workflow.compute(time_of_flight.TimeOfFlightLookupTable)\n", + "table = workflow.compute(time_of_flight.TimeOfFlightLookupTable).squeeze()\n", "table.plot() / (sc.stddevs(table) / sc.values(table)).plot(norm=\"log\")" ] }, @@ -720,9 +717,9 @@ "metadata": {}, "outputs": [], "source": [ - "workflow[time_of_flight.LookupTableRelativeErrorThreshold] = 1.0e-2\n", + "workflow[time_of_flight.LookupTableRelativeErrorThreshold] = 0.01\n", "\n", - "workflow.compute(time_of_flight.MaskedTimeOfFlightLookupTable).plot()" + "workflow.compute(time_of_flight.TimeOfFlightLookupTable).squeeze().plot()" ] }, { @@ -757,7 +754,7 @@ "\n", "pp.plot(\n", " {\n", - " \"wfm\": wav_wfm.hist(wavelength=wavs).sum(\"pulse\"),\n", + " \"wfm\": wav_wfm.hist(wavelength=wavs).squeeze(),\n", " \"ground_truth\": ground_truth.hist(wavelength=wavs),\n", " }\n", ")" diff --git a/docs/user-guide/tof/frame-unwrapping.ipynb b/docs/user-guide/tof/frame-unwrapping.ipynb 
index aca44aae..7d97d691 100644 --- a/docs/user-guide/tof/frame-unwrapping.ipynb +++ b/docs/user-guide/tof/frame-unwrapping.ipynb @@ -88,22 +88,12 @@ "\n", "model = tof.Model(source=source, choppers=[chopper], detectors=detectors)\n", "results = model.run()\n", - "pl = results.plot(cmap=\"viridis_r\")\n", + "pl = results.plot()\n", "\n", - "for i in range(source.pulses):\n", + "for i in range(2 * source.pulses):\n", " pl.ax.axvline(\n", " i * (1.0 / source.frequency).to(unit=\"us\").value, color=\"k\", ls=\"dotted\"\n", - " )\n", - " x = [\n", - " results[det.name].toas.data[\"visible\"][f\"pulse:{i}\"].coords[\"toa\"].min().value\n", - " for det in detectors\n", - " ]\n", - " y = [det.distance.value for det in detectors]\n", - " pl.ax.plot(x, y, \"--o\", color=\"magenta\", lw=3)\n", - " if i == 0:\n", - " pl.ax.text(\n", - " x[2], y[2] * 1.05, \"pivot time\", va=\"bottom\", ha=\"right\", color=\"magenta\"\n", - " )" + " )" ] }, { @@ -142,14 +132,6 @@ " .hist(event_time_offset=200)\n", " .plot(title=f\"{det.name}={det.distance:c}\", color=f\"C{i}\")\n", " )\n", - " f = subplots[i // 2, i % 2]\n", - " xpiv = min(\n", - " da.coords[\"toa\"].min() % (1.0 / source.frequency).to(unit=\"us\")\n", - " for da in results[det.name].toas.data[\"visible\"].values()\n", - " ).value\n", - " f.ax.axvline(xpiv, ls=\"dashed\", color=\"magenta\", lw=2)\n", - " f.ax.text(xpiv, 20, \"pivot time\", rotation=90, color=\"magenta\")\n", - " f.canvas.draw()\n", "subplots" ] }, @@ -157,39 +139,23 @@ "cell_type": "markdown", "id": "7", "metadata": {}, - "source": [ - "### Pivot time\n", - "\n", - "To compute the time-of-flight for a neutron, we need to identify which source pulse it originated from.\n", - "\n", - "In the first figure, the pink lines represent the earliest recorded arrival time at each detector:\n", - "we know that within a given frame at a selected detector,\n", - "any neutron recorded at a time earlier than this 'pivot' time must from from a previous pulse.\n", - "\n", - 
"The position of the pink lines is repeated in the second figure (above).\n", - "We can use this knowledge to unwrap the frames and compute the absolute time-of-arrival of the neutrons at the detectors." - ] - }, - { - "cell_type": "markdown", - "id": "8", - "metadata": {}, "source": [ "### Computing time-of-flight\n", "\n", - "The pivot time and the resulting offsets can be computed from the properties of the source pulse and the chopper cascade.\n", - "\n", "We describe in this section the workflow that computes time-of-flight,\n", "given `event_time_zero` and `event_time_offset` for neutron events,\n", "as well as the properties of the source pulse and the choppers in the beamline.\n", "\n", + "In short, we use a lookup table which can predict the wavelength (or time-of-flight) of the neutrons,\n", + "according to their `event_time_offset`.\n", + "\n", "The workflow can be visualized as follows:" ] }, { "cell_type": "code", "execution_count": null, - "id": "9", + "id": "8", "metadata": {}, "outputs": [], "source": [ @@ -221,95 +187,16 @@ }, { "cell_type": "markdown", - "id": "10", - "metadata": {}, - "source": [ - "#### Unwrapped neutron time-of-arrival\n", - "\n", - "The first step that is computed in the workflow is the unwrapped detector arrival time of each neutron.\n", - "This is essentially just `event_time_offset + event_time_zero`." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "11", - "metadata": {}, - "outputs": [], - "source": [ - "da = workflow.compute(time_of_flight.UnwrappedTimeOfArrival)[\n", - " \"detector_number\", 2\n", - "] # Look at a single detector\n", - "da.bins.concat().value.hist(time_of_arrival=300).plot()" - ] - }, - { - "cell_type": "markdown", - "id": "12", - "metadata": {}, - "source": [ - "#### Unwrapped neutron time-of-arrival minus pivot time\n", - "\n", - "The next step is to subtract the pivot time to the unwrapped arrival times,\n", - "to align the times so that they start at zero.\n", - "\n", - "This allows us to perform a computationally cheap modulo operation on the times below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "13", - "metadata": {}, - "outputs": [], - "source": [ - "da = workflow.compute(time_of_flight.UnwrappedTimeOfArrivalMinusPivotTime)[\n", - " \"detector_number\", 2\n", - "]\n", - "f = da.bins.concat().value.hist(time_of_arrival=300).plot()\n", - "for i in range(source.pulses):\n", - " f.ax.axvline(\n", - " i * (1.0 / source.frequency).to(unit=\"us\").value, color=\"k\", ls=\"dotted\"\n", - " )\n", - "f" - ] - }, - { - "cell_type": "markdown", - "id": "14", - "metadata": {}, - "source": [ - "The vertical dotted lines here represent the frame period.\n", - "\n", - "#### Unwrapped neutron time-of-arrival modulo the frame period\n", - "\n", - "We now wrap the arrival times with the frame period to obtain well formed (unbroken) set of events.\n", - "\n", - "We also re-add the pivot time offset we had subtracted earlier (to enable to modulo operation)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "15", - "metadata": {}, - "outputs": [], - "source": [ - "da = workflow.compute(time_of_flight.FrameFoldedTimeOfArrival)[\"detector_number\", 2]\n", - "da.bins.concat().value.hist(time_of_arrival=200).plot()" - ] - }, - { - "cell_type": "markdown", - "id": "16", + "id": "9", "metadata": {}, "source": [ - "#### Create a lookup table\n", + "#### Create the lookup table\n", "\n", - "The chopper information is next used to construct a lookup table that provides an estimate of the real time-of-flight as a function of time-of-arrival.\n", + "The chopper information is used to construct a lookup table that provides an estimate of the real time-of-flight as a function of time-of-arrival.\n", "\n", "The `tof` module can be used to propagate a pulse of neutrons through the chopper system to the detectors,\n", "and predict the most likely neutron wavelength for a given time-of-arrival.\n", + "More advanced programs such as McStas can of course also be used for even better results.\n", "\n", "We typically have hundreds of thousands of pixels in an instrument,\n", "but it is actually not necessary to propagate the neutrons to 105 detectors.\n", @@ -332,17 +219,17 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "10", "metadata": {}, "outputs": [], "source": [ - "table = workflow.compute(time_of_flight.TimeOfFlightLookupTable)\n", + "table = workflow.compute(time_of_flight.TimeOfFlightLookupTable).squeeze()\n", "table.plot()" ] }, { "cell_type": "markdown", - "id": "18", + "id": "11", "metadata": {}, "source": [ "#### Computing time-of-flight from the lookup\n", @@ -353,19 +240,19 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "12", "metadata": {}, "outputs": [], "source": [ "tofs = workflow.compute(time_of_flight.TofData)\n", "\n", - "tof_hist = tofs.bins.concat(\"pulse\").hist(tof=sc.scalar(500.0, unit=\"us\"))\n", + "tof_hist = 
tofs.hist(tof=sc.scalar(500.0, unit=\"us\"))\n", "pp.plot({det.name: tof_hist[\"detector_number\", i] for i, det in enumerate(detectors)})" ] }, { "cell_type": "markdown", - "id": "20", + "id": "13", "metadata": {}, "source": [ "### Converting to wavelength\n", @@ -379,7 +266,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -393,11 +280,7 @@ "bins = sc.linspace(\"wavelength\", 6.0, 9.0, 101, unit=\"angstrom\")\n", "\n", "# Compute wavelengths\n", - "wav_hist = (\n", - " tofs.transform_coords(\"wavelength\", graph=graph)\n", - " .bins.concat(\"pulse\")\n", - " .hist(wavelength=bins)\n", - ")\n", + "wav_hist = tofs.transform_coords(\"wavelength\", graph=graph).hist(wavelength=bins)\n", "wavs = {det.name: wav_hist[\"detector_number\", i] for i, det in enumerate(detectors)}\n", "\n", "ground_truth = results[\"detector\"].data.flatten(to=\"event\")\n", @@ -411,7 +294,7 @@ }, { "cell_type": "markdown", - "id": "22", + "id": "15", "metadata": {}, "source": [ "We see that all detectors agree on the wavelength spectrum,\n", @@ -432,7 +315,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "16", "metadata": {}, "outputs": [], "source": [ @@ -467,7 +350,7 @@ }, { "cell_type": "markdown", - "id": "24", + "id": "17", "metadata": {}, "source": [ "### Computing time-of-flight\n", @@ -481,7 +364,7 @@ { "cell_type": "code", "execution_count": null, - "id": "25", + "id": "18", "metadata": {}, "outputs": [], "source": [ @@ -517,27 +400,27 @@ }, { "cell_type": "markdown", - "id": "26", + "id": "19", "metadata": {}, "source": [ - "If we inspect the time-of-flight lookup table,\n", - "we can see that the time-of-arrival (toa) dimension now spans longer than the pulse period of 71 ms." 
+ "The time-of-flight lookup table now has an extra `pulse` dimension for each pulse in the stride:" ] }, { "cell_type": "code", "execution_count": null, - "id": "27", + "id": "20", "metadata": {}, "outputs": [], "source": [ "table = workflow.compute(time_of_flight.TimeOfFlightLookupTable)\n", - "table.plot()" + "\n", + "table[\"pulse\", 0].plot(title=\"Pulse-0\") + table[\"pulse\", 1].plot(title=\"Pulse-1\")" ] }, { "cell_type": "markdown", - "id": "28", + "id": "21", "metadata": {}, "source": [ "The time-of-flight profiles are then:" @@ -546,19 +429,19 @@ { "cell_type": "code", "execution_count": null, - "id": "29", + "id": "22", "metadata": {}, "outputs": [], "source": [ "tofs = workflow.compute(time_of_flight.TofData)\n", "\n", - "tof_hist = tofs.bins.concat(\"pulse\").hist(tof=sc.scalar(500.0, unit=\"us\"))\n", + "tof_hist = tofs.hist(tof=sc.scalar(500.0, unit=\"us\"))\n", "pp.plot({det.name: tof_hist[\"detector_number\", i] for i, det in enumerate(detectors)})" ] }, { "cell_type": "markdown", - "id": "30", + "id": "23", "metadata": {}, "source": [ "### Conversion to wavelength\n", @@ -569,7 +452,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31", + "id": "24", "metadata": {}, "outputs": [], "source": [ @@ -577,11 +460,7 @@ "bins = sc.linspace(\"wavelength\", 1.0, 8.0, 401, unit=\"angstrom\")\n", "\n", "# Compute wavelengths\n", - "wav_hist = (\n", - " tofs.transform_coords(\"wavelength\", graph=graph)\n", - " .bins.concat(\"pulse\")\n", - " .hist(wavelength=bins)\n", - ")\n", + "wav_hist = tofs.transform_coords(\"wavelength\", graph=graph).hist(wavelength=bins)\n", "wavs = {det.name: wav_hist[\"detector_number\", i] for i, det in enumerate(detectors)}\n", "\n", "ground_truth = results[\"detector\"].data.flatten(to=\"event\")\n", diff --git a/docs/user-guide/tof/wfm.ipynb b/docs/user-guide/tof/wfm.ipynb index 05446d71..3c312141 100644 --- a/docs/user-guide/tof/wfm.ipynb +++ b/docs/user-guide/tof/wfm.ipynb @@ -259,7 +259,7 @@ "raw_data = 
ess_beamline.get_monitor(\"detector\")[0]\n", "\n", "# Visualize\n", - "raw_data.hist(event_time_offset=300).sum(\"pulse\").plot()" + "raw_data.hist(event_time_offset=300).squeeze().plot()" ] }, { @@ -394,7 +394,7 @@ "metadata": {}, "outputs": [], "source": [ - "table = workflow.compute(time_of_flight.TimeOfFlightLookupTable)\n", + "table = workflow.compute(time_of_flight.TimeOfFlightLookupTable).squeeze()\n", "\n", "# Overlay mean on the figure above\n", "table[\"distance\", 1].plot(ax=fig2.ax, color=\"C1\", ls=\"-\", marker=None)\n", @@ -477,7 +477,8 @@ "# Define wavelength bin edges\n", "wavs = sc.linspace(\"wavelength\", 2, 10, 301, unit=\"angstrom\")\n", "\n", - "wav_wfm.hist(wavelength=wavs).sum(\"pulse\").plot()" + "histogrammed = wav_wfm.hist(wavelength=wavs).squeeze()\n", + "histogrammed.plot()" ] }, { @@ -503,7 +504,7 @@ "\n", "pp.plot(\n", " {\n", - " \"wfm\": wav_wfm.hist(wavelength=wavs).sum(\"pulse\"),\n", + " \"wfm\": histogrammed,\n", " \"ground_truth\": ground_truth.hist(wavelength=wavs),\n", " }\n", ")" diff --git a/requirements/base.txt b/requirements/base.txt index 070daf3c..beeeab62 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -11,7 +11,7 @@ cyclebane==24.10.0 # via sciline cycler==0.12.1 # via matplotlib -fonttools==4.55.4 +fonttools==4.56.0 # via matplotlib h5py==3.12.1 # via diff --git a/requirements/basetest.txt b/requirements/basetest.txt index 4408cd1f..471ae175 100644 --- a/requirements/basetest.txt +++ b/requirements/basetest.txt @@ -7,7 +7,7 @@ # asttokens==3.0.0 # via stack-data -certifi==2024.12.14 +certifi==2025.1.31 # via requests charset-normalizer==3.4.1 # via requests @@ -25,7 +25,7 @@ exceptiongroup==1.2.2 # pytest executing==2.2.0 # via stack-data -fonttools==4.55.4 +fonttools==4.56.0 # via matplotlib idna==3.10 # via requests @@ -33,7 +33,7 @@ importlib-resources==6.5.2 # via tof iniconfig==2.0.0 # via pytest -ipython==8.31.0 +ipython==8.32.0 # via ipywidgets ipywidgets==8.1.5 # via -r basetest.in @@ 
-98,7 +98,7 @@ six==1.17.0 # via python-dateutil stack-data==0.6.3 # via ipython -tof==25.1.2 +tof==25.2.0 # via -r basetest.in tomli==2.2.1 # via pytest diff --git a/requirements/ci.txt b/requirements/ci.txt index 9ffaa1db..10820857 100644 --- a/requirements/ci.txt +++ b/requirements/ci.txt @@ -7,7 +7,7 @@ # cachetools==5.5.1 # via tox -certifi==2024.12.14 +certifi==2025.1.31 # via requests chardet==5.2.0 # via tox diff --git a/requirements/dev.txt b/requirements/dev.txt index de09d3ae..f5ef3159 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -59,7 +59,7 @@ jsonschema[format-nongpl]==4.23.0 # jupyter-events # jupyterlab-server # nbformat -jupyter-events==0.11.0 +jupyter-events==0.12.0 # via jupyter-server jupyter-lsp==2.2.5 # via jupyterlab @@ -71,7 +71,7 @@ jupyter-server==2.15.0 # notebook-shim jupyter-server-terminals==0.5.3 # via jupyter-server -jupyterlab==4.3.4 +jupyterlab==4.3.5 # via -r dev.in jupyterlab-server==2.27.3 # via jupyterlab @@ -91,7 +91,7 @@ prometheus-client==0.21.1 # via jupyter-server pycparser==2.22 # via cffi -pydantic==2.10.5 +pydantic==2.10.6 # via copier pydantic-core==2.27.2 # via pydantic diff --git a/requirements/docs.txt b/requirements/docs.txt index bf65e28f..6dc1d93b 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -12,21 +12,21 @@ alabaster==1.0.0 # via sphinx asttokens==3.0.0 # via stack-data -attrs==24.3.0 +attrs==25.1.0 # via # jsonschema # referencing -babel==2.16.0 +babel==2.17.0 # via # pydata-sphinx-theme # sphinx -beautifulsoup4==4.12.3 +beautifulsoup4==4.13.3 # via # nbconvert # pydata-sphinx-theme bleach[css]==6.2.0 # via nbconvert -certifi==2024.12.14 +certifi==2025.1.31 # via requests charset-normalizer==3.4.1 # via requests @@ -62,7 +62,7 @@ importlib-resources==6.5.2 # via tof ipykernel==6.29.5 # via -r docs.in -ipython==8.31.0 +ipython==8.32.0 # via # -r docs.in # ipykernel @@ -112,13 +112,13 @@ mdit-py-plugins==0.4.2 # via myst-parser mdurl==0.1.2 # via markdown-it-py 
-mistune==3.1.0 +mistune==3.1.1 # via nbconvert myst-parser==4.0.0 # via -r docs.in nbclient==0.10.2 # via nbconvert -nbconvert==7.16.5 +nbconvert==7.16.6 # via nbsphinx nbformat==5.10.4 # via @@ -156,11 +156,11 @@ pygments==2.19.1 # sphinx pyyaml==6.0.2 # via myst-parser -pyzmq==26.2.0 +pyzmq==26.2.1 # via # ipykernel # jupyter-client -referencing==0.36.1 +referencing==0.36.2 # via # jsonschema # jsonschema-specifications @@ -205,7 +205,7 @@ stack-data==0.6.3 # via ipython tinycss2==1.4.0 # via bleach -tof==25.1.2 +tof==25.2.0 # via -r docs.in tomli==2.2.1 # via sphinx @@ -228,6 +228,7 @@ traitlets==5.14.3 # nbsphinx typing-extensions==4.12.2 # via + # beautifulsoup4 # ipython # mistune # pydata-sphinx-theme diff --git a/requirements/mypy.txt b/requirements/mypy.txt index d7a49e8a..61d88db1 100644 --- a/requirements/mypy.txt +++ b/requirements/mypy.txt @@ -6,7 +6,7 @@ # pip-compile-multi # -r test.txt -mypy==1.14.1 +mypy==1.15.0 # via -r mypy.in mypy-extensions==1.0.0 # via mypy diff --git a/requirements/nightly.txt b/requirements/nightly.txt index e4af6e81..210bf281 100644 --- a/requirements/nightly.txt +++ b/requirements/nightly.txt @@ -5,9 +5,11 @@ # # pip-compile-multi # +annotated-types==0.7.0 + # via pydantic asttokens==3.0.0 # via stack-data -certifi==2024.12.14 +certifi==2025.1.31 # via requests charset-normalizer==3.4.1 # via requests @@ -23,25 +25,31 @@ cycler==0.12.1 # via matplotlib decorator==5.1.1 # via ipython +dnspython==2.7.0 + # via email-validator +email-validator==2.2.0 + # via scippneutron exceptiongroup==1.2.2 # via # ipython # pytest executing==2.2.0 # via stack-data -fonttools==4.55.4 +fonttools==4.56.0 # via matplotlib h5py==3.12.1 # via # scippneutron # scippnexus idna==3.10 - # via requests + # via + # email-validator + # requests importlib-resources==6.5.2 # via tof iniconfig==2.0.0 # via pytest -ipython==8.31.0 +ipython==8.32.0 # via ipywidgets ipywidgets==8.1.5 # via -r nightly.in @@ -51,6 +59,8 @@ jupyterlab-widgets==3.0.13 # via 
ipywidgets kiwisolver==1.4.8 # via matplotlib +lazy-loader==0.4 + # via scippneutron matplotlib==3.10.0 # via # mpltoolbox @@ -72,6 +82,7 @@ numpy==2.2.2 # scipy packaging==24.2 # via + # lazy-loader # matplotlib # pooch # pytest @@ -97,6 +108,10 @@ ptyprocess==0.7.0 # via pexpect pure-eval==0.2.3 # via stack-data +pydantic==2.10.6 + # via scippneutron +pydantic-core==2.27.2 + # via pydantic pygments==2.19.1 # via ipython pyparsing==3.2.1 @@ -106,6 +121,7 @@ pytest==8.3.4 python-dateutil==2.9.0.post0 # via # matplotlib + # scippneutron # scippnexus requests==2.32.3 # via pooch @@ -144,7 +160,10 @@ traitlets==5.14.3 # ipywidgets # matplotlib-inline typing-extensions==4.12.2 - # via ipython + # via + # ipython + # pydantic + # pydantic-core urllib3==2.3.0 # via requests wcwidth==0.2.13 diff --git a/src/ess/reduce/time_of_flight/__init__.py b/src/ess/reduce/time_of_flight/__init__.py index 21650df0..baf03e63 100644 --- a/src/ess/reduce/time_of_flight/__init__.py +++ b/src/ess/reduce/time_of_flight/__init__.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) """ Utilities for computing real neutron time-of-flight from chopper settings and @@ -7,52 +7,39 @@ """ from .simulation import simulate_beamline -from .toa_to_tof import default_parameters, resample_tof_data, providers, TofWorkflow +from .toa_to_tof import default_parameters, resample_tof_data, providers from .to_events import to_events from .types import ( DistanceResolution, - FrameFoldedTimeOfArrival, - FramePeriod, LookupTableRelativeErrorThreshold, Ltotal, LtotalRange, - MaskedTimeOfFlightLookupTable, - PivotTimeAtDetector, PulsePeriod, PulseStride, PulseStrideOffset, RawData, ResampledTofData, SimulationResults, - TimeOfArrivalMinusPivotTimeModuloPeriod, TimeOfFlightLookupTable, + TimeResolution, TofData, - UnwrappedTimeOfArrival, - 
UnwrappedTimeOfArrivalMinusPivotTime, ) __all__ = [ "DistanceResolution", - "FrameFoldedTimeOfArrival", - "FramePeriod", "LookupTableRelativeErrorThreshold", "Ltotal", "LtotalRange", - "MaskedTimeOfFlightLookupTable", - "PivotTimeAtDetector", "PulsePeriod", "PulseStride", "PulseStrideOffset", "RawData", "ResampledTofData", "SimulationResults", - "TimeOfArrivalMinusPivotTimeModuloPeriod", "TimeOfFlightLookupTable", + "TimeResolution", "TofData", - "TofWorkflow", - "UnwrappedTimeOfArrival", - "UnwrappedTimeOfArrivalMinusPivotTime", "default_parameters", "providers", "resample_tof_data", diff --git a/src/ess/reduce/time_of_flight/fakes.py b/src/ess/reduce/time_of_flight/fakes.py index 5679b7d5..acb38241 100644 --- a/src/ess/reduce/time_of_flight/fakes.py +++ b/src/ess/reduce/time_of_flight/fakes.py @@ -21,6 +21,7 @@ def __init__( monitors: dict[str, sc.Variable], run_length: sc.Variable, events_per_pulse: int = 200000, + seed: int | None = None, source: Callable | None = None, ): import math @@ -35,7 +36,10 @@ def __init__( # Create a source if source is None: self.source = tof_pkg.Source( - facility="ess", neutrons=self.events_per_pulse, pulses=self.npulses + facility="ess", + neutrons=self.events_per_pulse, + pulses=self.npulses, + seed=seed, ) else: self.source = source(pulses=self.npulses) @@ -69,34 +73,10 @@ def __init__( self.model_result = self.model.run() def get_monitor(self, name: str) -> sc.DataGroup: - # Create some fake pulse time zero - start = sc.datetime("2024-01-01T12:00:00.000000") - period = sc.reciprocal(self.frequency) - - detector = self.model_result.detectors[name] - raw_data = detector.data.flatten(to="event") - # Select only the neutrons that make it to the detector + nx_event_data = self.model_result.to_nxevent_data(name) + raw_data = self.model_result.detectors[name].data.flatten(to="event") raw_data = raw_data[~raw_data.masks["blocked_by_others"]].copy() - raw_data.coords["Ltotal"] = detector.distance - - # Format the data in a way that 
resembles data loaded from NeXus - event_data = raw_data.copy(deep=False) - dt = period.to(unit="us") - event_time_zero = (dt * (event_data.coords["toa"] // dt)).to(dtype=int) + start - raw_data.coords["event_time_zero"] = event_time_zero - event_data.coords["event_time_zero"] = event_time_zero - event_data.coords["event_time_offset"] = ( - event_data.coords.pop("toa").to(unit="s") % period - ) - del event_data.coords["tof"] - del event_data.coords["speed"] - del event_data.coords["time"] - del event_data.coords["wavelength"] - - return ( - event_data.group("event_time_zero").rename_dims(event_time_zero="pulse"), - raw_data.group("event_time_zero").rename_dims(event_time_zero="pulse"), - ) + return nx_event_data, raw_data wfm1_chopper = DiskChopper( @@ -194,25 +174,6 @@ def get_monitor(self, name: str) -> sc.DataGroup: radius=sc.scalar(30.0, unit="cm"), ) -pulse_skipping = DiskChopper( - frequency=sc.scalar(-7.0, unit="Hz"), - beam_position=sc.scalar(0.0, unit="deg"), - phase=sc.scalar(0.0, unit="deg"), - axle_position=sc.vector(value=[0, 0, 30.0], unit="m"), - slit_begin=sc.array( - dims=["cutout"], - values=np.array([40.0]), - unit="deg", - ), - slit_end=sc.array( - dims=["cutout"], - values=np.array([140.0]), - unit="deg", - ), - slit_height=sc.scalar(10.0, unit="cm"), - radius=sc.scalar(30.0, unit="cm"), -) - def wfm_choppers(): return { @@ -238,3 +199,24 @@ def psc_choppers(): ) for name, ch in wfm_choppers().items() } + + +def pulse_skipping_chopper(): + return DiskChopper( + frequency=sc.scalar(-7.0, unit="Hz"), + beam_position=sc.scalar(0.0, unit="deg"), + phase=sc.scalar(0.0, unit="deg"), + axle_position=sc.vector(value=[0, 0, 30.0], unit="m"), + slit_begin=sc.array( + dims=["cutout"], + values=np.array([40.0]), + unit="deg", + ), + slit_end=sc.array( + dims=["cutout"], + values=np.array([140.0]), + unit="deg", + ), + slit_height=sc.scalar(10.0, unit="cm"), + radius=sc.scalar(30.0, unit="cm"), + ) diff --git a/src/ess/reduce/time_of_flight/simulation.py 
b/src/ess/reduce/time_of_flight/simulation.py index 7bdf79aa..5aa4d617 100644 --- a/src/ess/reduce/time_of_flight/simulation.py +++ b/src/ess/reduce/time_of_flight/simulation.py @@ -11,6 +11,7 @@ def simulate_beamline( choppers: Mapping[str, DiskChopper], neutrons: int = 1_000_000, + pulses: int = 1, seed: int | None = None, facility: str = 'ess', ) -> SimulationResults: @@ -26,6 +27,8 @@ def simulate_beamline( for more information. neutrons: Number of neutrons to simulate. + pulses: + Number of pulses to simulate. seed: Seed for the random number generator used in the simulation. facility: @@ -47,9 +50,9 @@ def simulate_beamline( ) for name, ch in choppers.items() ] - source = tof.Source(facility=facility, neutrons=neutrons, seed=seed) + source = tof.Source(facility=facility, neutrons=neutrons, pulses=pulses, seed=seed) if not tof_choppers: - events = source.data.squeeze() + events = source.data.squeeze().flatten(to='event') return SimulationResults( time_of_arrival=events.coords["time"], speed=events.coords["speed"], @@ -61,7 +64,7 @@ def simulate_beamline( results = model.run() # Find name of the furthest chopper in tof_choppers furthest_chopper = max(tof_choppers, key=lambda c: c.distance) - events = results[furthest_chopper.name].data.squeeze() + events = results[furthest_chopper.name].data.squeeze().flatten(to='event') events = events[ ~(events.masks["blocked_by_others"] | events.masks["blocked_by_me"]) ] diff --git a/src/ess/reduce/time_of_flight/to_events.py b/src/ess/reduce/time_of_flight/to_events.py index 8afef366..512e0ddc 100644 --- a/src/ess/reduce/time_of_flight/to_events.py +++ b/src/ess/reduce/time_of_flight/to_events.py @@ -34,15 +34,20 @@ def to_events( rng = np.random.default_rng() event_coords = {} edge_dims = [] - midp_dims = [] + midp_dims = set() + midp_coord_names = [] # Separate bin-edge and midpoints coords - for dim in da.dims: - if da.coords.is_edges(dim): - edge_dims.append(dim) + for name in da.coords: + dims = da.coords[name].dims + 
is_edges = False if not dims else da.coords.is_edges(name) + if is_edges: + if name in dims: + edge_dims.append(name) else: - midp_dims.append(dim) + midp_coord_names.append(name) + midp_dims.update(set(dims)) - edge_sizes = {dim: da.sizes[dim] for dim in edge_dims} + edge_sizes = {dim: da.sizes[da.coords[dim].dim] for dim in edge_dims} for dim in edge_dims: coord = da.coords[dim] left = sc.broadcast(coord[dim, :-1], sizes=edge_sizes).values @@ -102,5 +107,5 @@ def to_events( dims=[*edge_dims, event_dim], to=event_dim ) return new.assign_coords( - {dim: da.coords[dim].copy() for dim in midp_dims} + {dim: da.coords[dim].copy() for dim in midp_coord_names} ).assign_masks({key: mask.copy() for key, mask in other_masks.items()}) diff --git a/src/ess/reduce/time_of_flight/toa_to_tof.py b/src/ess/reduce/time_of_flight/toa_to_tof.py index 0bb1d50e..558e3083 100644 --- a/src/ess/reduce/time_of_flight/toa_to_tof.py +++ b/src/ess/reduce/time_of_flight/toa_to_tof.py @@ -17,45 +17,21 @@ from .to_events import to_events from .types import ( DistanceResolution, - FastestNeutron, - FrameFoldedTimeOfArrival, - FramePeriod, LookupTableRelativeErrorThreshold, Ltotal, LtotalRange, - MaskedTimeOfFlightLookupTable, - PivotTimeAtDetector, PulsePeriod, PulseStride, PulseStrideOffset, RawData, ResampledTofData, SimulationResults, - TimeOfArrivalMinusPivotTimeModuloPeriod, - TimeOfArrivalResolution, TimeOfFlightLookupTable, + TimeResolution, TofData, - UnwrappedTimeOfArrival, - UnwrappedTimeOfArrivalMinusPivotTime, ) -def frame_period(pulse_period: PulsePeriod, pulse_stride: PulseStride) -> FramePeriod: - """ - Return the period of a frame, which is defined by the pulse period times the pulse - stride. - - Parameters - ---------- - pulse_period: - Period of the source pulses, i.e., time between consecutive pulse starts. - pulse_stride: - Stride of used pulses. Usually 1, but may be a small integer when - pulse-skipping. 
- """ - return FramePeriod(pulse_period * pulse_stride) - - def extract_ltotal(da: RawData) -> Ltotal: """ Extract the total length of the flight path from the source to the detector from the @@ -70,80 +46,91 @@ def extract_ltotal(da: RawData) -> Ltotal: return Ltotal(da.coords["Ltotal"]) -def compute_tof_lookup_table( +def _mask_large_uncertainty(table: sc.DataArray, error_threshold: float): + """ + Mask regions with large uncertainty with NaNs. + The values are modified in place in the input table. + + Parameters + ---------- + table: + Lookup table with time-of-flight as a function of distance and time-of-arrival. + error_threshold: + Threshold for the relative standard deviation (coefficient of variation) of the + projected time-of-flight above which values are masked. + """ + # Finally, mask regions with large uncertainty with NaNs. + relative_error = sc.stddevs(table.data) / sc.values(table.data) + mask = relative_error > sc.scalar(error_threshold) + # Use numpy for indexing as table is 2D + table.values[mask.values] = np.nan + + +def _compute_mean_tof_in_distance_range( simulation: SimulationResults, - ltotal_range: LtotalRange, - distance_resolution: DistanceResolution, - toa_resolution: TimeOfArrivalResolution, -) -> TimeOfFlightLookupTable: + distance_bins: sc.Variable, + time_bins: sc.Variable, + distance_unit: str, + time_unit: str, + frame_period: sc.Variable, + time_bins_half_width: sc.Variable, +) -> sc.DataArray: """ - Compute a lookup table for time-of-flight as a function of distance and - time-of-arrival. + Compute the mean time-of-flight inside event_time_offset bins for a given range of + distances. Parameters ---------- simulation: Results of a time-of-flight simulation used to create a lookup table. - The results should be a flat table with columns for time-of-arrival, speed, - wavelength, and weight. - ltotal_range: - Range of total flight path lengths from the source to the detector. 
- distance_resolution: - Resolution of the distance axis in the lookup table. - toa_resolution: - Resolution of the time-of-arrival axis in the lookup table. + distance_bins: + Bin edges for the distance axis in the lookup table. + time_bins: + Bin edges for the event_time_offset axis in the lookup table. + distance_unit: + Unit of the distance axis. + time_unit: + Unit of the event_time_offset axis. + frame_period: + Period of a frame, i.e., the pulse period multiplied by the pulse stride. + time_bins_half_width: + Half the width of the time bins. """ - distance_unit = "m" - res = distance_resolution.to(unit=distance_unit) simulation_distance = simulation.distance.to(unit=distance_unit) - - min_dist, max_dist = ( - x.to(unit=distance_unit) - simulation_distance for x in ltotal_range - ) - # We need to bin the data below, to compute the weighted mean of the wavelength. - # This results in data with bin edges. - # However, the 2d interpolator expects bin centers. - # We want to give the 2d interpolator a table that covers the requested range, - # hence we need to extend the range by at least half a resolution in each direction. - # Then, we make the choice that the resolution in distance is the quantity that - # should be preserved. Because the difference between min and max distance is - # not necessarily an integer multiple of the resolution, we need to add a pad to - # ensure that the last bin is not cut off. We want the upper edge to be higher than - # the maximum distance, hence we pad with an additional 1.5 x resolution.
- pad = 2.0 * res - dist_edges = sc.array( - dims=["distance"], - values=np.arange((min_dist - pad).value, (max_dist + pad).value, res.value), - unit=distance_unit, - ) - distances = sc.midpoints(dist_edges) - - time_unit = simulation.time_of_arrival.unit + distances = sc.midpoints(distance_bins) + # Compute arrival and flight times for all neutrons toas = simulation.time_of_arrival + (distances / simulation.speed).to( unit=time_unit, copy=False ) - - # Compute time-of-flight for all neutrons - wavs = sc.broadcast(simulation.wavelength.to(unit="m"), sizes=toas.sizes).flatten( - to="event" - ) - dist = sc.broadcast(distances + simulation_distance, sizes=toas.sizes).flatten( - to="event" - ) - tofs = dist * sc.constants.m_n - tofs *= wavs - tofs /= sc.constants.h + dist = distances + simulation_distance + tofs = dist * (sc.constants.m_n / sc.constants.h) * simulation.wavelength data = sc.DataArray( - data=sc.broadcast(simulation.weight, sizes=toas.sizes).flatten(to="event"), + data=sc.broadcast(simulation.weight, sizes=toas.sizes), coords={ - "toa": toas.flatten(to="event"), + "toa": toas, "tof": tofs.to(unit=time_unit, copy=False), "distance": dist, }, + ).flatten(to="event") + + # Add the event_time_offset coordinate to the data. We first operate on the + # frame period. The table will later be folded to the pulse period. + data.coords['event_time_offset'] = data.coords['toa'] % frame_period + + # Because we staggered the mesh by half a bin width, we want the values above + # the last bin edge to wrap around to the first bin. + # Technically, those values should end up between -0.5*bin_width and 0, but + # a simple modulo also works here because even if they end up between 0 and + # 0.5*bin_width, we are (below) computing the mean between -0.5*bin_width and + # 0.5*bin_width and it yields the same result. 
+ # data.coords['event_time_offset'] %= pulse_period - time_bins_half_width + data.coords['event_time_offset'] %= frame_period - time_bins_half_width + + binned = data.bin( + distance=distance_bins + simulation_distance, event_time_offset=time_bins ) - binned = data.bin(distance=dist_edges + simulation_distance, toa=toa_resolution) # Weighted mean of tof inside each bin mean_tof = ( binned.bins.data * binned.bins.coords["tof"] @@ -154,188 +141,316 @@ def compute_tof_lookup_table( ).bins.sum() / binned.bins.sum() mean_tof.variances = variance.values + return mean_tof - # Convert coordinates to midpoints - mean_tof.coords["toa"] = sc.midpoints(mean_tof.coords["toa"]) - mean_tof.coords["distance"] = sc.midpoints(mean_tof.coords["distance"]) - return TimeOfFlightLookupTable(mean_tof) - - -def masked_tof_lookup_table( - tof_lookup: TimeOfFlightLookupTable, - error_threshold: LookupTableRelativeErrorThreshold, -) -> MaskedTimeOfFlightLookupTable: +def _fold_table_to_pulse_period( + table: sc.DataArray, pulse_period: sc.Variable, pulse_stride: int +) -> sc.DataArray: """ - Mask regions of the lookup table where the variance of the projected time-of-flight - is larger than a given threshold. + Fold the lookup table to the pulse period. We make sure the left and right edges of + the table wrap around the ``event_time_offset`` dimension. Parameters ---------- - tof_lookup: - Lookup table giving time-of-flight as a function of distance and - time-of-arrival. - variance_threshold: - Threshold for the variance of the projected time-of-flight above which regions - are masked. + table: + Lookup table with time-of-flight as a function of distance and time-of-arrival. + pulse_period: + Period of the source pulses, i.e., time between consecutive pulse starts. + pulse_stride: + Stride of used pulses. Usually 1, but may be a small integer when + pulse-skipping. 
""" - relative_error = sc.stddevs(tof_lookup.data) / sc.values(tof_lookup.data) - mask = relative_error > sc.scalar(error_threshold) - out = tof_lookup.copy() - # Use numpy for indexing as table is 2D - out.values[mask.values] = np.nan - return MaskedTimeOfFlightLookupTable(out) - + size = table.sizes['event_time_offset'] + if (size % pulse_stride) != 0: + raise ValueError( + "TimeOfFlightLookupTable: the number of time bins must be a multiple of " + f"the pulse stride, but got {size} time bins and a pulse stride of " + f"{pulse_stride}." + ) -def find_fastest_neutron(simulation: SimulationResults) -> FastestNeutron: - """ - Find the fastest neutron in the simulation results. - """ - ind = np.argmax(simulation.speed.values) - return FastestNeutron( - time_of_arrival=simulation.time_of_arrival[ind], - speed=simulation.speed[ind], - distance=simulation.distance, + size = size // pulse_stride + out = sc.concat([table, table['event_time_offset', 0]], dim='event_time_offset') + out = sc.concat( + [ + out['event_time_offset', (i * size) : (i + 1) * size + 1] + for i in range(pulse_stride) + ], + dim='pulse', ) - - -def pivot_time_at_detector( - fastest_neutron: FastestNeutron, ltotal: Ltotal -) -> PivotTimeAtDetector: - """ - Compute the pivot time at the detector, i.e., the time of the start of the frame at - the detector. - The assumption here is that the fastest neutron in the simulation results is the one - that arrives at the detector first. - One could have an edge case where a slightly slower neutron which is born earlier - could arrive at the detector first, but this edge case is most probably uncommon, - and the difference in arrival times is likely to be small. - - Parameters - ---------- - fastest_neutron: - Properties of the fastest neutron in the simulation results. - ltotal: - Total length of the flight path from the source to the detector. 
- """ - dist = ltotal - fastest_neutron.distance.to(unit=ltotal.unit) - toa = fastest_neutron.time_of_arrival + (dist / fastest_neutron.speed).to( - unit=fastest_neutron.time_of_arrival.unit, copy=False + return out.assign_coords( + event_time_offset=sc.concat( + [ + table.coords['event_time_offset']['event_time_offset', :size], + pulse_period, + ], + 'event_time_offset', + ) ) - return PivotTimeAtDetector(toa) -def unwrapped_time_of_arrival( - da: RawData, offset: PulseStrideOffset, pulse_period: PulsePeriod -) -> UnwrappedTimeOfArrival: +def compute_tof_lookup_table( + simulation: SimulationResults, + ltotal_range: LtotalRange, + distance_resolution: DistanceResolution, + time_resolution: TimeResolution, + pulse_period: PulsePeriod, + pulse_stride: PulseStride, + error_threshold: LookupTableRelativeErrorThreshold, +) -> TimeOfFlightLookupTable: """ - Compute the unwrapped time of arrival of the neutron at the detector. - For event data, this is essentially ``event_time_offset + event_time_zero``. + Compute a lookup table for time-of-flight as a function of distance and + time-of-arrival. Parameters ---------- - da: - Raw detector data loaded from a NeXus file, e.g., NXdetector containing - NXevent_data. - offset: - Integer offset of the first pulse in the stride (typically zero unless we are - using pulse-skipping and the events do not begin with the first pulse in the - stride). + simulation: + Results of a time-of-flight simulation used to create a lookup table. + The results should be a flat table with columns for time-of-arrival, speed, + wavelength, and weight. + ltotal_range: + Range of total flight path lengths from the source to the detector. + distance_resolution: + Resolution of the distance axis in the lookup table. + time_resolution: + Resolution of the time-of-arrival axis in the lookup table. Must be an integer. pulse_period: Period of the source pulses, i.e., time between consecutive pulse starts. + pulse_stride: + Stride of used pulses. 
Usually 1, but may be a small integer when + pulse-skipping. + error_threshold: + Threshold for the relative standard deviation (coefficient of variation) of the + projected time-of-flight above which values are masked. """ - if da.bins is None: - # 'time_of_flight' is the canonical name in NXmonitor, but in some files, it - # may be called 'tof'. - key = next(iter(set(da.coords.keys()) & {"time_of_flight", "tof"})) - toa = da.coords[key] - else: - # To unwrap the time of arrival, we want to add the event_time_zero to the - # event_time_offset. However, we do not really care about the exact datetimes, - # we just want to know the offsets with respect to the start of the run. - # Hence we use the smallest event_time_zero as the time origin. - time_zero = da.coords["event_time_zero"] - da.coords["event_time_zero"].min() - coord = da.bins.coords["event_time_offset"] - unit = elem_unit(coord) - toa = ( - coord - + time_zero.to(dtype=float, unit=unit, copy=False) - - (offset * pulse_period).to(unit=unit, copy=False) + distance_unit = "m" + time_unit = simulation.time_of_arrival.unit + res = distance_resolution.to(unit=distance_unit) + pulse_period = pulse_period.to(unit=time_unit) + frame_period = pulse_period * pulse_stride + + min_dist, max_dist = ( + x.to(unit=distance_unit) - simulation.distance.to(unit=distance_unit) + for x in ltotal_range + ) + # We need to bin the data below, to compute the weighted mean of the wavelength. + # This results in data with bin edges. + # However, the 2d interpolator expects bin centers. + # We want to give the 2d interpolator a table that covers the requested range, + # hence we need to extend the range by at least half a resolution in each direction. + # Then, we make the choice that the resolution in distance is the quantity that + # should be preserved. 
Because the difference between min and max distance is + # not necessarily an integer multiple of the resolution, we need to add a pad to + # ensure that the last bin is not cut off. We want the upper edge to be higher than + # the maximum distance, hence we pad with an additional 1.5 x resolution. + pad = 2.0 * res + distance_bins = sc.arange('distance', min_dist - pad, max_dist + pad, res) + + # Create some time bins for event_time_offset. + # We want our final table to strictly cover the range [0, frame_period]. + # However, binning the data associates mean values inside the bins to the bin + # centers. Instead, we stagger the mesh by half a bin width so we are computing + # values for the final mesh edges (the bilinear interpolation needs values on the + # edges/corners). + nbins = int(frame_period / time_resolution.to(unit=time_unit)) + 1 + time_bins = sc.linspace( + 'event_time_offset', 0.0, frame_period.value, nbins + 1, unit=pulse_period.unit + ) + time_bins_half_width = 0.5 * (time_bins[1] - time_bins[0]) + time_bins -= time_bins_half_width + + # To avoid a too large RAM usage, we compute the table in chunks, and piece them + # together at the end. 
+ ndist = len(distance_bins) - 1 + max_size = 2e7 + total_size = ndist * len(simulation.time_of_arrival) + nchunks = total_size / max_size + chunk_size = int(ndist / nchunks) + 1 + pieces = [] + for i in range(int(nchunks) + 1): + dist_edges = distance_bins[i * chunk_size : (i + 1) * chunk_size + 1] + + pieces.append( + _compute_mean_tof_in_distance_range( + simulation=simulation, + distance_bins=dist_edges, + time_bins=time_bins, + distance_unit=distance_unit, + time_unit=time_unit, + frame_period=frame_period, + time_bins_half_width=time_bins_half_width, + ) ) - return UnwrappedTimeOfArrival(toa) + table = sc.concat(pieces, 'distance') + table.coords["distance"] = sc.midpoints(table.coords["distance"]) + table.coords["event_time_offset"] = sc.midpoints(table.coords["event_time_offset"]) -def unwrapped_time_of_arrival_minus_frame_pivot_time( - toa: UnwrappedTimeOfArrival, pivot_time: PivotTimeAtDetector -) -> UnwrappedTimeOfArrivalMinusPivotTime: - """ - Compute the time of arrival of the neutron at the detector, unwrapped at the pulse - period, minus the start time of the frame. - We subtract the start time of the frame so that we can use a modulo operation to - wrap the time of arrival at the frame period in the case of pulse-skipping. + table = _fold_table_to_pulse_period( + table=table, pulse_period=pulse_period, pulse_stride=pulse_stride + ) - Parameters - ---------- - toa: - Time of arrival of the neutron at the detector, unwrapped at the pulse period. - pivot_time: - Pivot time at the detector, i.e., the time of the start of the frame at the - detector. 
- """ - # Order of operation to preserve dimension order - return UnwrappedTimeOfArrivalMinusPivotTime( - -pivot_time.to(unit=elem_unit(toa), copy=False) + toa + # In-place masking for better performance + _mask_large_uncertainty(table, error_threshold) + + return TimeOfFlightLookupTable( + table.transpose(('pulse', 'distance', 'event_time_offset')) ) -def time_of_arrival_minus_pivot_time_modulo_period( - toa_minus_pivot_time: UnwrappedTimeOfArrivalMinusPivotTime, - frame_period: FramePeriod, -) -> TimeOfArrivalMinusPivotTimeModuloPeriod: - """ - Compute the time of arrival of the neutron at the detector, unwrapped at the pulse - period, minus the start time of the frame, modulo the frame period. +def _make_tof_interpolator( + lookup: sc.DataArray, distance_unit: str, time_unit: str +) -> Callable: + from scipy.interpolate import RegularGridInterpolator - Parameters - ---------- - toa_minus_pivot_time: - Time of arrival of the neutron at the detector, unwrapped at the pulse period, - minus the start time of the frame. - frame_period: - Period of the frame, i.e., time between the start of two consecutive frames. - """ - return TimeOfArrivalMinusPivotTimeModuloPeriod( - toa_minus_pivot_time - % frame_period.to(unit=elem_unit(toa_minus_pivot_time), copy=False) + # TODO: to make use of multi-threading, we could write our own interpolator. + # This should be simple enough as we are making the bins linspace, so computing + # bin indices is fast. + + # In the pulse dimension, it could be that for a given event_time_offset and + # distance, a tof value is finite in one pulse and NaN in the other. + # When using the bilinear interpolation, even if the value of the requested point is + # exactly 0 or 1 (in the case of pulse_stride=2), the interpolator will still + # use all 4 corners surrounding the point. This means that if one of the corners + # is NaN, the result will be NaN. 
+ # Here, we use a trick where we duplicate the lookup values in the 'pulse' dimension + # so that the interpolator has values on bin edges for that dimension. + # The interpolator raises an error if axes coordinates are not strictly monotonic, + # so we cannot use e.g. [-0.5, 0.5, 0.5, 1.5] in the case of pulse_stride=2. + # Instead we use [-0.25, 0.25, 0.75, 1.25]. + base_grid = np.arange(float(lookup.sizes["pulse"])) + return RegularGridInterpolator( + ( + np.sort(np.concatenate([base_grid - 0.25, base_grid + 0.25])), + lookup.coords["distance"].to(unit=distance_unit, copy=False).values, + lookup.coords["event_time_offset"].to(unit=time_unit, copy=False).values, + ), + np.repeat(lookup.data.to(unit=time_unit, copy=False).values, 2, axis=0), + method="linear", + bounds_error=False, + fill_value=np.nan, ) -def time_of_arrival_folded_by_frame( - toa: TimeOfArrivalMinusPivotTimeModuloPeriod, - pivot_time: PivotTimeAtDetector, -) -> FrameFoldedTimeOfArrival: - """ - The time of arrival of the neutron at the detector, folded by the frame period. +def _time_of_flight_data_histogram( + da: sc.DataArray, + lookup: sc.DataArray, + ltotal: sc.Variable, + pulse_period: sc.Variable, +) -> sc.DataArray: + # In NeXus, 'time_of_flight' is the canonical name in NXmonitor, but in some files, + # it may be called 'tof'. + key = next(iter(set(da.coords.keys()) & {"time_of_flight", "tof"})) + eto_unit = da.coords[key].unit + pulse_period = pulse_period.to(unit=eto_unit) + + # In histogram mode, because there is a wrap around at the end of the pulse, we + # need to insert a bin edge at that exact location to avoid having the last bin + # with one finite left edge and a NaN right edge (it becomes NaN as it would be + # outside the range of the lookup table). 
+ new_bins = sc.sort( + sc.concat( + [da.coords[key], sc.scalar(0.0, unit=eto_unit), pulse_period], dim=key + ), + key=key, + ) + rebinned = da.rebin({key: new_bins}) + etos = rebinned.coords[key] + + # In histogram mode, the lookup table cannot have a pulse dimension because we + # cannot know in the histogrammed data which pulse the events belong to. + # So we merge the pulse dimension in the lookup table. A quick way to do this + # is to take the mean of the data along the pulse dimension (there should + # only be regions that are NaN in one pulse and finite in the other). + merged = lookup.data.nanmean('pulse') + dim = merged.dims[0] + lookup = sc.DataArray( + data=merged.fold(dim=dim, sizes={'pulse': 1, dim: merged.sizes[dim]}), + coords={ + 'pulse': sc.arange('pulse', 1.0), + 'distance': lookup.coords['distance'], + 'event_time_offset': lookup.coords['event_time_offset'], + }, + ) + pulse_index = sc.zeros(sizes=etos.sizes) - Parameters - ---------- - toa: - Time of arrival of the neutron at the detector, unwrapped at the pulse period, - minus the start time of the frame, modulo the frame period. - pivot_time: - Pivot time at the detector, i.e., the time of the start of the frame at the - detector. 
- """ - return FrameFoldedTimeOfArrival( - toa + pivot_time.to(unit=elem_unit(toa), copy=False) + # Create 2D interpolator + interp = _make_tof_interpolator( + lookup, distance_unit=ltotal.unit, time_unit=eto_unit + ) + + # Compute time-of-flight of the bin edges using the interpolator + tofs = sc.array( + dims=etos.dims, + values=interp((pulse_index.values, ltotal.values, etos.values)), + unit=eto_unit, + ) + + return rebinned.assign_coords(tof=tofs) + + +def _time_of_flight_data_events( + da: sc.DataArray, + lookup: sc.DataArray, + ltotal: sc.Variable, + pulse_period: sc.Variable, + pulse_stride: int, + pulse_stride_offset: int, +) -> sc.DataArray: + etos = da.bins.coords["event_time_offset"] + eto_unit = elem_unit(etos) + pulse_period = pulse_period.to(unit=eto_unit) + frame_period = pulse_period * pulse_stride + + # TODO: Finding the `tmin` below will not work in the case were data is processed + # in chunks, as taking the minimum time in each chunk will lead to inconsistent + # pulse indices (this will be the case in live data, or when using the + # StreamProcessor). We could instead read it from the first chunk and store it? + + # Compute a pulse index for every event: it is the index of the pulse within a + # frame period. When there is no pulse skipping, those are all zero. When there is + # pulse skipping, the index ranges from zero to pulse_stride - 1. 
+ tmin = da.bins.coords['event_time_zero'].min() + pulse_index = ( + ( + (da.bins.coords['event_time_zero'] - tmin).to(unit=eto_unit) + + 0.5 * pulse_period + ) + % frame_period + ) // pulse_period + # Apply the pulse_stride_offset + pulse_index += pulse_stride_offset + pulse_index %= pulse_stride + + # Create 2D interpolator + interp = _make_tof_interpolator( + lookup, distance_unit=ltotal.unit, time_unit=eto_unit + ) + + # Operate on events (broadcast distances to all events) + ltotal = sc.bins_like(etos, ltotal).bins.constituents["data"] + etos = etos.bins.constituents["data"] + pulse_index = pulse_index.bins.constituents["data"] + + # Compute time-of-flight for all neutrons using the interpolator + tofs = sc.array( + dims=etos.dims, + values=interp((pulse_index.values, ltotal.values, etos.values)), + unit=eto_unit, ) + parts = da.bins.constituents + parts["data"] = tofs + return da.bins.assign_coords(tof=_bins_no_validate(**parts)) + def time_of_flight_data( da: RawData, - lookup: MaskedTimeOfFlightLookupTable, + lookup: TimeOfFlightLookupTable, ltotal: Ltotal, - toas: FrameFoldedTimeOfArrival, + pulse_period: PulsePeriod, + pulse_stride: PulseStride, + pulse_stride_offset: PulseStrideOffset, ) -> TofData: """ Convert the time-of-arrival data to time-of-flight data using a lookup table. @@ -351,39 +466,29 @@ def time_of_flight_data( arrival. ltotal: Total length of the flight path from the source to the detector. - toas: - Time of arrival of the neutron at the detector, folded by the frame period. + pulse_period: + Period of the source pulses, i.e., time between consecutive pulse starts. + pulse_stride: + Stride of used pulses. Usually 1, but may be a small integer when + pulse-skipping. + pulse_stride_offset: + When pulse-skipping, the offset of the first pulse in the stride. This is + typically zero but can be a small integer < pulse_stride. 
""" - from scipy.interpolate import RegularGridInterpolator - # TODO: to make use of multi-threading, we could write our own interpolator. - # This should be simple enough as we are making the bins linspace, so computing - # bin indices is fast. - f = RegularGridInterpolator( - ( - lookup.coords["toa"].to(unit=elem_unit(toas), copy=False).values, - lookup.coords["distance"].to(unit=ltotal.unit, copy=False).values, - ), - lookup.data.to(unit=elem_unit(toas), copy=False).values.T, - method="linear", - bounds_error=False, - ) - - if da.bins is not None: - ltotal = sc.bins_like(toas, ltotal).bins.constituents["data"] - toas = toas.bins.constituents["data"] - - tofs = sc.array( - dims=toas.dims, values=f((toas.values, ltotal.values)), unit=elem_unit(toas) - ) - - if da.bins is not None: - parts = da.bins.constituents - parts["data"] = tofs - out = da.bins.assign_coords(tof=_bins_no_validate(**parts)) + if da.bins is None: + out = _time_of_flight_data_histogram( + da=da, lookup=lookup, ltotal=ltotal, pulse_period=pulse_period + ) else: - out = da.assign_coords(tof=tofs) - + out = _time_of_flight_data_events( + da=da, + lookup=lookup, + ltotal=ltotal, + pulse_period=pulse_period, + pulse_stride=pulse_stride, + pulse_stride_offset=pulse_stride_offset, + ) return TofData(out) @@ -432,7 +537,7 @@ def default_parameters() -> dict: PulseStride: 1, PulseStrideOffset: 0, DistanceResolution: sc.scalar(0.1, unit="m"), - TimeOfArrivalResolution: 500, + TimeResolution: sc.scalar(250.0, unit='us'), LookupTableRelativeErrorThreshold: 0.1, } @@ -441,101 +546,4 @@ def providers() -> tuple[Callable]: """ Providers of the time-of-flight workflow. 
""" - return ( - compute_tof_lookup_table, - extract_ltotal, - find_fastest_neutron, - frame_period, - masked_tof_lookup_table, - pivot_time_at_detector, - time_of_arrival_folded_by_frame, - time_of_arrival_minus_pivot_time_modulo_period, - time_of_flight_data, - unwrapped_time_of_arrival, - unwrapped_time_of_arrival_minus_frame_pivot_time, - ) - - -class TofWorkflow: - """ - Helper class to build a time-of-flight workflow and cache the expensive part of the - computation: running the simulation and building the lookup table. - - Parameters - ---------- - simulated_neutrons: - Results of a time-of-flight simulation used to create a lookup table. - The results should be a flat table with columns for time-of-arrival, speed, - wavelength, and weight. - ltotal_range: - Range of total flight path lengths from the source to the detector. - This is used to create the lookup table to compute the neutron - time-of-flight. - Note that the resulting table will extend slightly beyond this range, as the - supplied range is not necessarily a multiple of the distance resolution. - pulse_stride: - Stride of used pulses. Usually 1, but may be a small integer when - pulse-skipping. - pulse_stride_offset: - Integer offset of the first pulse in the stride (typically zero unless we - are using pulse-skipping and the events do not begin with the first pulse in - the stride). - distance_resolution: - Resolution of the distance axis in the lookup table. - Should be a single scalar value with a unit of length. - This is typically of the order of 1-10 cm. - toa_resolution: - Resolution of the time of arrival axis in the lookup table. - Can be an integer (number of bins) or a sc.Variable (bin width). - error_threshold: - Threshold for the variance of the projected time-of-flight above which - regions are masked. 
- """ - - def __init__( - self, - simulated_neutrons: SimulationResults, - ltotal_range: LtotalRange, - pulse_stride: PulseStride | None = None, - pulse_stride_offset: PulseStrideOffset | None = None, - distance_resolution: DistanceResolution | None = None, - toa_resolution: TimeOfArrivalResolution | None = None, - error_threshold: LookupTableRelativeErrorThreshold | None = None, - ): - import sciline as sl - - self.pipeline = sl.Pipeline(providers()) - self.pipeline[SimulationResults] = simulated_neutrons - self.pipeline[LtotalRange] = ltotal_range - - params = default_parameters() - self.pipeline[PulsePeriod] = params[PulsePeriod] - self.pipeline[PulseStride] = pulse_stride or params[PulseStride] - self.pipeline[PulseStrideOffset] = ( - pulse_stride_offset or params[PulseStrideOffset] - ) - self.pipeline[DistanceResolution] = ( - distance_resolution or params[DistanceResolution] - ) - self.pipeline[TimeOfArrivalResolution] = ( - toa_resolution or params[TimeOfArrivalResolution] - ) - self.pipeline[LookupTableRelativeErrorThreshold] = ( - error_threshold or params[LookupTableRelativeErrorThreshold] - ) - - def cache_results( - self, - results=(SimulationResults, MaskedTimeOfFlightLookupTable, FastestNeutron), - ) -> None: - """ - Cache a list of (usually expensive to compute) intermediate results of the - time-of-flight workflow. - - Parameters - ---------- - results: - List of results to cache. - """ - for t in results: - self.pipeline[t] = self.pipeline.compute(t) + return (compute_tof_lookup_table, extract_ltotal, time_of_flight_data) diff --git a/src/ess/reduce/time_of_flight/types.py b/src/ess/reduce/time_of_flight/types.py index 027c9164..973f3ada 100644 --- a/src/ess/reduce/time_of_flight/types.py +++ b/src/ess/reduce/time_of_flight/types.py @@ -48,17 +48,6 @@ class SimulationResults: distance: sc.Variable -@dataclass -class FastestNeutron: - """ - Properties of the fastest neutron in the simulation results. 
- """ - - time_of_arrival: sc.Variable - speed: sc.Variable - distance: sc.Variable - - LtotalRange = NewType("LtotalRange", tuple[sc.Variable, sc.Variable]) """ Range (min, max) of the total length of the flight path from the source to the detector. @@ -79,10 +68,16 @@ class FastestNeutron: This is typically of the order of 1-10 cm. """ -TimeOfArrivalResolution = NewType("TimeOfArrivalResolution", int | sc.Variable) +TimeResolution = NewType("TimeResolution", sc.Variable) """ -Resolution of the time of arrival axis in the lookup table. -Can be an integer (number of bins) or a sc.Variable (bin width). +Step size of the event_time_offset axis in the lookup table. +This is basically the 'time-of-flight' resolution of the detector. +Should be a single scalar value with a unit of time. +This is typically of the order of 0.1-0.5 ms. + +Since the event_time_offset range needs to span exactly one pulse period, the final +resolution in the lookup table will be at least the supplied value here, but may be +smaller if the pulse period is not an integer multiple of the time resolution. """ TimeOfFlightLookupTable = NewType("TimeOfFlightLookupTable", sc.DataArray) @@ -90,48 +85,12 @@ class FastestNeutron: Lookup table giving time-of-flight as a function of distance and time of arrival. """ -MaskedTimeOfFlightLookupTable = NewType("MaskedTimeOfFlightLookupTable", sc.DataArray) -""" -Lookup table giving time-of-flight as a function of distance and time of arrival, with -regions of large uncertainty masked out. -""" - LookupTableRelativeErrorThreshold = NewType("LookupTableRelativeErrorThreshold", float) - -FramePeriod = NewType("FramePeriod", sc.Variable) """ -The period of a frame, a (small) integer multiple of the source period. +Threshold for the relative standard deviation (coefficient of variation) of the +projected time-of-flight above which values are masked. 
""" -UnwrappedTimeOfArrival = NewType("UnwrappedTimeOfArrival", sc.Variable) -""" -Time of arrival of the neutron at the detector, unwrapped at the pulse period. -""" - -PivotTimeAtDetector = NewType("PivotTimeAtDetector", sc.Variable) -""" -Pivot time at the detector, i.e., the time of the start of the frame at the detector. -""" - -UnwrappedTimeOfArrivalMinusPivotTime = NewType( - "UnwrappedTimeOfArrivalMinusPivotTime", sc.Variable -) -""" -Time of arrival of the neutron at the detector, unwrapped at the pulse period, minus -the start time of the frame. -""" - -TimeOfArrivalMinusPivotTimeModuloPeriod = NewType( - "TimeOfArrivalMinusPivotTimeModuloPeriod", sc.Variable -) -""" -Time of arrival of the neutron at the detector minus the start time of the frame, -modulo the frame period. -""" - -FrameFoldedTimeOfArrival = NewType("FrameFoldedTimeOfArrival", sc.Variable) - - PulsePeriod = NewType("PulsePeriod", sc.Variable) """ Period of the source pulses, i.e., time between consecutive pulse starts. @@ -144,7 +103,8 @@ class FastestNeutron: PulseStrideOffset = NewType("PulseStrideOffset", int) """ -When pulse-skipping, the offset of the first pulse in the stride. +When pulse-skipping, the offset of the first pulse in the stride. This is typically +zero but can be a small integer < pulse_stride. 
""" RawData = NewType("RawData", sc.DataArray) diff --git a/tests/time_of_flight/to_events_test.py b/tests/time_of_flight/to_events_test.py index 114cd91c..52ed7944 100644 --- a/tests/time_of_flight/to_events_test.py +++ b/tests/time_of_flight/to_events_test.py @@ -17,6 +17,29 @@ def test_to_events_1d(): assert sc.allclose(hist.data, result.data) +def test_to_events_1d_with_non_dim_coord(): + table = sc.data.table_xyz(1000) + hist = table.hist(x=20) + hist.coords["y"] = hist.coords["x"] * 2 + events = to_events(hist, "event") + assert "x" not in events.dims + assert "y" not in events.dims + assert "x" in events.coords + assert "y" not in events.coords + result = events.hist(x=hist.coords["x"]) + assert sc.identical(hist.coords["x"], result.coords["x"]) + assert sc.allclose(hist.data, result.data) + + +def test_to_events_1d_scalar_coord_is_preserved(): + table = sc.data.table_xyz(1000) + hist = table.hist(x=20) + hist.coords["y"] = sc.scalar(1.0, unit="m") + events = to_events(hist, "event") + assert "x" not in events.dims + assert "y" in events.coords + + def test_to_events_1d_with_group_coord(): table = sc.data.table_xyz(1000, coord_max=10) table.coords["l"] = table.coords["x"].to(dtype=int) @@ -42,6 +65,23 @@ def test_to_events_2d(): assert sc.allclose(hist.data, result.data) +def test_to_events_2d_with_non_dim_coord(): + table = sc.data.table_xyz(1000) + hist = table.hist(y=20, x=10) + hist.coords["z"] = hist.coords["x"] * 2 + events = to_events(hist, "event") + assert "x" not in events.dims + assert "y" not in events.dims + assert "z" not in events.dims + assert "x" in events.coords + assert "y" in events.coords + assert "z" not in events.coords + result = events.hist(y=hist.coords["y"], x=hist.coords["x"]) + assert sc.identical(hist.coords["x"], result.coords["x"]) + assert sc.identical(hist.coords["y"], result.coords["y"]) + assert sc.allclose(hist.data, result.data) + + def test_to_events_2d_with_group_coord(): table = sc.data.table_xyz(1000, coord_max=10) 
table.coords["l"] = table.coords["x"].to(dtype=int) diff --git a/tests/time_of_flight/unwrap_test.py b/tests/time_of_flight/unwrap_test.py index d351cb6d..5744b199 100644 --- a/tests/time_of_flight/unwrap_test.py +++ b/tests/time_of_flight/unwrap_test.py @@ -3,7 +3,6 @@ import numpy as np import pytest import scipp as sc -from scipp.testing import assert_identical from scippneutron.conversion.graph.beamline import beamline as beamline_graph from scippneutron.conversion.graph.tof import elastic as elastic_graph @@ -13,23 +12,6 @@ sl = pytest.importorskip("sciline") -def test_frame_period_is_pulse_period_if_not_pulse_skipping() -> None: - pl = sl.Pipeline(time_of_flight.providers()) - period = sc.scalar(123.0, unit="ms") - pl[time_of_flight.PulsePeriod] = period - pl[time_of_flight.PulseStride] = 1 - assert_identical(pl.compute(time_of_flight.FramePeriod), period) - - -@pytest.mark.parametrize("stride", [1, 2, 3, 4]) -def test_frame_period_is_multiple_pulse_period_if_pulse_skipping(stride) -> None: - pl = sl.Pipeline(time_of_flight.providers()) - period = sc.scalar(123.0, unit="ms") - pl[time_of_flight.PulsePeriod] = period - pl[time_of_flight.PulseStride] = stride - assert_identical(pl.compute(time_of_flight.FramePeriod), stride * period) - - def test_unwrap_with_no_choppers() -> None: # At this small distance the frames are not overlapping (with the given wavelength # range), despite not using any choppers. 
@@ -40,6 +22,7 @@ def test_unwrap_with_no_choppers() -> None: monitors={"detector": distance}, run_length=sc.scalar(1 / 14, unit="s") * 4, events_per_pulse=100_000, + seed=1, ) mon, ref = beamline.get_monitor("detector") @@ -60,7 +43,6 @@ def test_unwrap_with_no_choppers() -> None: # Convert to wavelength graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} wavs = tofs.transform_coords("wavelength", graph=graph).bins.concat().value - ref = ref.bins.concat().value diff = abs( (wavs.coords["wavelength"] - ref.coords["wavelength"]) @@ -68,6 +50,9 @@ def test_unwrap_with_no_choppers() -> None: ) # Most errors should be small assert np.nanpercentile(diff.values, 96) < 1.0 + # Make sure that we have not lost too many events (we lose some because they may be + # given a NaN tof from the lookup). + assert sc.isclose(mon.data.nansum(), tofs.data.nansum(), rtol=sc.scalar(1.0e-3)) # At 80m, event_time_offset does not wrap around (all events are within the same pulse). @@ -81,6 +66,7 @@ def test_standard_unwrap(dist) -> None: monitors={"detector": distance}, run_length=sc.scalar(1 / 14, unit="s") * 4, events_per_pulse=100_000, + seed=2, ) mon, ref = beamline.get_monitor("detector") @@ -101,7 +87,6 @@ def test_standard_unwrap(dist) -> None: # Convert to wavelength graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} wavs = tofs.transform_coords("wavelength", graph=graph).bins.concat().value - ref = ref.bins.concat().value diff = abs( (wavs.coords["wavelength"] - ref.coords["wavelength"]) @@ -109,6 +94,9 @@ def test_standard_unwrap(dist) -> None: ) # All errors should be small assert np.nanpercentile(diff.values, 100) < 0.01 + # Make sure that we have not lost too many events (we lose some because they may be + # given a NaN tof from the lookup). + assert sc.isclose(mon.data.nansum(), tofs.data.nansum(), rtol=sc.scalar(1.0e-3)) # At 80m, event_time_offset does not wrap around (all events are within the same pulse). 
@@ -123,17 +111,14 @@ def test_standard_unwrap_histogram_mode(dist, dim) -> None: monitors={"detector": distance}, run_length=sc.scalar(1 / 14, unit="s") * 4, events_per_pulse=100_000, + seed=3, ) mon, ref = beamline.get_monitor("detector") - mon = ( - mon.hist( - event_time_offset=sc.linspace( - "event_time_offset", 0.0, 1000.0 / 14, num=1001, unit="ms" - ).to(unit="s") - ) - .sum("pulse") - .rename(event_time_offset=dim) - ) + mon = mon.hist( + event_time_offset=sc.linspace( + "event_time_offset", 0.0, 1000.0 / 14, num=1001, unit="ms" + ).to(unit=mon.bins.coords["event_time_offset"].bins.unit) + ).rename(event_time_offset=dim) sim = time_of_flight.simulate_beamline( choppers=choppers, neutrons=300_000, seed=1234 @@ -150,23 +135,27 @@ def test_standard_unwrap_histogram_mode(dist, dim) -> None: tofs = pl.compute(time_of_flight.ResampledTofData) graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} wavs = tofs.transform_coords("wavelength", graph=graph) - ref = ref.bins.concat().value.hist(wavelength=wavs.coords["wavelength"]) + ref = ref.hist(wavelength=wavs.coords["wavelength"]) # We divide by the maximum to avoid large relative differences at the edges of the # frames where the counts are low. diff = (wavs - ref) / ref.max() assert np.nanpercentile(diff.values, 96.0) < 0.3 + # Make sure that we have not lost too many events (we lose some because they may be + # given a NaN tof from the lookup). 
+ assert sc.isclose(mon.data.nansum(), tofs.data.nansum(), rtol=sc.scalar(1.0e-3)) def test_pulse_skipping_unwrap() -> None: distance = sc.scalar(100.0, unit="m") choppers = fakes.psc_choppers() - choppers["pulse_skipping"] = fakes.pulse_skipping + choppers["pulse_skipping"] = fakes.pulse_skipping_chopper() beamline = fakes.FakeBeamline( choppers=choppers, monitors={"detector": distance}, run_length=sc.scalar(1.0, unit="s"), events_per_pulse=100_000, + seed=4, ) mon, ref = beamline.get_monitor("detector") @@ -188,7 +177,6 @@ def test_pulse_skipping_unwrap() -> None: # Convert to wavelength graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} wavs = tofs.transform_coords("wavelength", graph=graph).bins.concat().value - ref = ref.bins.concat().value diff = abs( (wavs.coords["wavelength"] - ref.coords["wavelength"]) @@ -196,18 +184,68 @@ def test_pulse_skipping_unwrap() -> None: ) # All errors should be small assert np.nanpercentile(diff.values, 100) < 0.01 + # Make sure that we have not lost too many events (we lose some because they may be + # given a NaN tof from the lookup). 
+ assert sc.isclose(mon.data.nansum(), tofs.data.nansum(), rtol=sc.scalar(1.0e-3)) + + +def test_pulse_skipping_unwrap_180_phase_shift() -> None: + distance = sc.scalar(100.0, unit="m") + choppers = fakes.psc_choppers() + choppers["pulse_skipping"] = fakes.pulse_skipping_chopper() + choppers["pulse_skipping"].phase.value += 180.0 + + beamline = fakes.FakeBeamline( + choppers=choppers, + monitors={"detector": distance}, + run_length=sc.scalar(1.0, unit="s"), + events_per_pulse=100_000, + seed=4, + ) + mon, ref = beamline.get_monitor("detector") + + sim = time_of_flight.simulate_beamline( + choppers=choppers, neutrons=300_000, pulses=2, seed=1234 + ) + + pl = sl.Pipeline( + time_of_flight.providers(), params=time_of_flight.default_parameters() + ) + + pl[time_of_flight.RawData] = mon + pl[time_of_flight.SimulationResults] = sim + pl[time_of_flight.LtotalRange] = distance, distance + pl[time_of_flight.PulseStride] = 2 + pl[time_of_flight.PulseStrideOffset] = 1 # Start the stride at the second pulse + + tofs = pl.compute(time_of_flight.TofData) + + # Convert to wavelength + graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} + wavs = tofs.transform_coords("wavelength", graph=graph).bins.concat().value + + diff = abs( + (wavs.coords["wavelength"] - ref.coords["wavelength"]) + / ref.coords["wavelength"] + ) + # All errors should be small + assert np.nanpercentile(diff.values, 100) < 0.01 + # Make sure that we have not lost too many events (we lose some because they may be + # given a NaN tof from the lookup). 
+ assert sc.isclose(mon.data.nansum(), tofs.data.nansum(), rtol=sc.scalar(1.0e-3)) def test_pulse_skipping_unwrap_when_all_neutrons_arrive_after_second_pulse() -> None: distance = sc.scalar(150.0, unit="m") choppers = fakes.psc_choppers() - choppers["pulse_skipping"] = fakes.pulse_skipping + choppers["pulse_skipping"] = fakes.pulse_skipping_chopper() beamline = fakes.FakeBeamline( choppers=choppers, monitors={"detector": distance}, run_length=sc.scalar(1.0, unit="s"), events_per_pulse=100_000, + seed=5, ) mon, ref = beamline.get_monitor("detector") @@ -230,7 +268,6 @@ def test_pulse_skipping_unwrap_when_all_neutrons_arrive_after_second_pulse() -> # Convert to wavelength graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} wavs = tofs.transform_coords("wavelength", graph=graph).bins.concat().value - ref = ref.bins.concat().value diff = abs( (wavs.coords["wavelength"] - ref.coords["wavelength"]) @@ -238,18 +275,22 @@ def test_pulse_skipping_unwrap_when_all_neutrons_arrive_after_second_pulse() -> ) # All errors should be small assert np.nanpercentile(diff.values, 100) < 0.01 + # Make sure that we have not lost too many events (we lose some because they may be + # given a NaN tof from the lookup). 
+ assert sc.isclose(mon.data.nansum(), tofs.data.nansum(), rtol=sc.scalar(1.0e-3)) def test_pulse_skipping_unwrap_when_first_half_of_first_pulse_is_missing() -> None: distance = sc.scalar(100.0, unit="m") choppers = fakes.psc_choppers() - choppers["pulse_skipping"] = fakes.pulse_skipping + choppers["pulse_skipping"] = fakes.pulse_skipping_chopper() beamline = fakes.FakeBeamline( choppers=choppers, monitors={"detector": distance}, run_length=sc.scalar(1.0, unit="s"), events_per_pulse=100_000, + seed=6, ) mon, ref = beamline.get_monitor("detector") @@ -261,9 +302,10 @@ def test_pulse_skipping_unwrap_when_first_half_of_first_pulse_is_missing() -> No time_of_flight.providers(), params=time_of_flight.default_parameters() ) - pl[time_of_flight.RawData] = mon[ - 1: - ].copy() # Skip first pulse = half of the first frame + # Skip first pulse = half of the first frame + a = mon.group('event_time_zero')['event_time_zero', 1:] + a.bins.coords['event_time_zero'] = sc.bins_like(a, a.coords['event_time_zero']) + pl[time_of_flight.RawData] = a.bins.concat('event_time_zero') pl[time_of_flight.SimulationResults] = sim pl[time_of_flight.LtotalRange] = distance, distance pl[time_of_flight.PulseStride] = 2 @@ -274,7 +316,24 @@ def test_pulse_skipping_unwrap_when_first_half_of_first_pulse_is_missing() -> No # Convert to wavelength graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} wavs = tofs.transform_coords("wavelength", graph=graph).bins.concat().value - ref = ref[1:].copy().bins.concat().value + # Bin the events in toa starting from the pulse period to skip the first pulse. + ref = ( + ref.bin( + toa=sc.concat( + [ + sc.scalar(1 / 14, unit='s').to(unit=ref.coords['toa'].unit), + ref.coords['toa'].max() * 1.01, + ], + dim='toa', + ) + ) + .bins.concat() + .value + ) + + # Sort the events according id to make sure we are comparing the same values. 
+ wavs = sc.sort(wavs, key=wavs.coords['id']) + ref = sc.sort(ref, key=ref.coords['id']) diff = abs( (wavs.coords["wavelength"] - ref.coords["wavelength"]) @@ -282,29 +341,33 @@ def test_pulse_skipping_unwrap_when_first_half_of_first_pulse_is_missing() -> No ) # All errors should be small assert np.nanpercentile(diff.values, 100) < 0.01 + # Make sure that we have not lost too many events (we lose some because they may be + # given a NaN tof from the lookup). + assert sc.isclose( + pl.compute(time_of_flight.RawData).data.nansum(), + tofs.data.nansum(), + rtol=sc.scalar(1.0e-3), + ) def test_pulse_skipping_unwrap_histogram_mode() -> None: distance = sc.scalar(100.0, unit="m") choppers = fakes.psc_choppers() - choppers["pulse_skipping"] = fakes.pulse_skipping + choppers["pulse_skipping"] = fakes.pulse_skipping_chopper() beamline = fakes.FakeBeamline( choppers=choppers, monitors={"detector": distance}, run_length=sc.scalar(1.0, unit="s"), events_per_pulse=100_000, + seed=7, ) mon, ref = beamline.get_monitor("detector") - mon = ( - mon.hist( - event_time_offset=sc.linspace( - "event_time_offset", 0.0, 1000.0 / 14, num=1001, unit="ms" - ).to(unit="s") - ) - .sum("pulse") - .rename(event_time_offset="time_of_flight") - ) + mon = mon.hist( + event_time_offset=sc.linspace( + "event_time_offset", 0.0, 1000.0 / 14, num=1001, unit="ms" + ).to(unit=mon.bins.coords["event_time_offset"].bins.unit) + ).rename(event_time_offset="time_of_flight") sim = time_of_flight.simulate_beamline( choppers=choppers, neutrons=300_000, seed=1234 @@ -322,8 +385,11 @@ def test_pulse_skipping_unwrap_histogram_mode() -> None: tofs = pl.compute(time_of_flight.ResampledTofData) graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} wavs = tofs.transform_coords("wavelength", graph=graph) - ref = ref.bins.concat().value.hist(wavelength=wavs.coords["wavelength"]) + ref = ref.hist(wavelength=wavs.coords["wavelength"]) # We divide by the maximum to avoid large relative differences at the edges of 
the # frames where the counts are low. diff = (wavs - ref) / ref.max() assert np.nanpercentile(diff.values, 96.0) < 0.3 + # Make sure that we have not lost too many events (we lose some because they may be + # given a NaN tof from the lookup). + assert sc.isclose(mon.data.nansum(), tofs.data.nansum(), rtol=sc.scalar(1.0e-3)) diff --git a/tests/time_of_flight/wfm_test.py b/tests/time_of_flight/wfm_test.py index a29948ad..18034ac5 100644 --- a/tests/time_of_flight/wfm_test.py +++ b/tests/time_of_flight/wfm_test.py @@ -1,14 +1,12 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2024 Scipp contributors (https://github.com/scipp) -from functools import partial - +import numpy as np import pytest import scipp as sc -import tof as tof_pkg from scippneutron.chopper import DiskChopper -from scippneutron.conversion.graph.beamline import beamline -from scippneutron.conversion.graph.tof import elastic +from scippneutron.conversion.graph.beamline import beamline as beamline_graph +from scippneutron.conversion.graph.tof import elastic as elastic_graph from ess.reduce import time_of_flight from ess.reduce.time_of_flight import fakes @@ -113,7 +111,6 @@ def simulation_dream_choppers(): ) -@pytest.mark.parametrize("npulses", [1, 2]) @pytest.mark.parametrize( "ltotal", [ @@ -128,88 +125,58 @@ def simulation_dream_choppers(): ) @pytest.mark.parametrize("time_offset_unit", ["s", "ms", "us", "ns"]) @pytest.mark.parametrize("distance_unit", ["m", "mm"]) -def test_dream_wfm( - simulation_dream_choppers, - npulses, - ltotal, - time_offset_unit, - distance_unit, -): +def test_dream_wfm(simulation_dream_choppers, ltotal, time_offset_unit, distance_unit): monitors = { f"detector{i}": ltot for i, ltot in enumerate(ltotal.flatten(to="detector")) } # Create some neutron events - wavelengths = sc.array( - dims=["event"], values=[1.5, 1.6, 1.7, 3.3, 3.4, 3.5], unit="angstrom" - ) - birth_times = sc.full(sizes=wavelengths.sizes, value=1.5, unit="ms") - ess_beamline = fakes.FakeBeamline( 
+ beamline = fakes.FakeBeamline( choppers=dream_choppers(), monitors=monitors, - run_length=sc.scalar(1 / 14, unit="s") * npulses, - events_per_pulse=len(wavelengths), - source=partial( - tof_pkg.Source.from_neutrons, - birth_times=birth_times, - wavelengths=wavelengths, - frequency=sc.scalar(14.0, unit="Hz"), - ), + run_length=sc.scalar(1 / 14, unit="s") * 4, + events_per_pulse=10_000, + seed=77, ) - # Save the true wavelengths for later - true_wavelengths = ess_beamline.source.data.coords["wavelength"] - - raw_data = sc.concat( - [ess_beamline.get_monitor(key)[0] for key in monitors.keys()], + raw = sc.concat( + [beamline.get_monitor(key)[0].squeeze() for key in monitors.keys()], dim="detector", ).fold(dim="detector", sizes=ltotal.sizes) # Convert the time offset to the unit requested by the test - raw_data.bins.coords["event_time_offset"] = raw_data.bins.coords[ - "event_time_offset" - ].to(unit=time_offset_unit, copy=False) - - raw_data.coords["Ltotal"] = ltotal.to(unit=distance_unit, copy=False) - - # Verify that all 6 neutrons made it through the chopper cascade - assert sc.identical( - raw_data.bins.concat("pulse").hist().data, - sc.array( - dims=["detector"], - values=[len(wavelengths) * npulses] * len(monitors), - unit="counts", - dtype="float64", - ).fold(dim="detector", sizes=ltotal.sizes), + raw.bins.coords["event_time_offset"] = raw.bins.coords["event_time_offset"].to( + unit=time_offset_unit, copy=False ) + # Convert the distance to the unit requested by the test + raw.coords["Ltotal"] = raw.coords["Ltotal"].to(unit=distance_unit, copy=False) + + # Save reference data + ref = beamline.get_monitor(next(iter(monitors)))[1].squeeze() + ref = sc.sort(ref, key='id') - # Set up the workflow - workflow = sl.Pipeline( + pl = sl.Pipeline( time_of_flight.providers(), params=time_of_flight.default_parameters() ) - workflow[time_of_flight.RawData] = raw_data - workflow[time_of_flight.SimulationResults] = simulation_dream_choppers - 
workflow[time_of_flight.LtotalRange] = ltotal.min(), ltotal.max() + pl[time_of_flight.RawData] = raw + pl[time_of_flight.SimulationResults] = simulation_dream_choppers + pl[time_of_flight.LtotalRange] = ltotal.min(), ltotal.max() - # Compute time-of-flight - tofs = workflow.compute(time_of_flight.TofData) - assert {dim: tofs.sizes[dim] for dim in ltotal.sizes} == ltotal.sizes + tofs = pl.compute(time_of_flight.TofData) # Convert to wavelength - graph = {**beamline(scatter=False), **elastic("tof")} - wav_wfm = tofs.transform_coords("wavelength", graph=graph) + graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} + wavs = tofs.transform_coords("wavelength", graph=graph) - # Compare the computed wavelengths to the true wavelengths - for i in range(npulses): - result = wav_wfm["pulse", i].flatten(to="detector") - for j in range(len(result)): - computed_wavelengths = result[j].values.coords["wavelength"] - assert sc.allclose( - computed_wavelengths, - true_wavelengths["pulse", i], - rtol=sc.scalar(1e-02), - ) + for da in wavs.flatten(to='pixel'): + x = sc.sort(da.value, key='id') + diff = abs( + (x.coords["wavelength"] - ref.coords["wavelength"]) + / ref.coords["wavelength"] + ) + assert np.nanpercentile(diff.values, 100) < 0.02 + assert sc.isclose(ref.data.sum(), da.data.sum(), rtol=sc.scalar(1.0e-3)) @pytest.fixture(scope="module") @@ -219,7 +186,6 @@ def simulation_dream_choppers_time_overlap(): ) -@pytest.mark.parametrize("npulses", [1, 2]) @pytest.mark.parametrize( "ltotal", [ @@ -236,7 +202,6 @@ def simulation_dream_choppers_time_overlap(): @pytest.mark.parametrize("distance_unit", ["m", "mm"]) def test_dream_wfm_with_subframe_time_overlap( simulation_dream_choppers_time_overlap, - npulses, ltotal, time_offset_unit, distance_unit, @@ -246,91 +211,56 @@ def test_dream_wfm_with_subframe_time_overlap( } # Create some neutron events - wavelengths = [1.5, 1.6, 1.7, 3.3, 3.4, 3.5] - birth_times = [1.5] * len(wavelengths) - - # Add overlap neutrons - 
birth_times.extend([0.0, 3.3]) - wavelengths.extend([2.6, 2.4]) - - wavelengths = sc.array(dims=["event"], values=wavelengths, unit="angstrom") - birth_times = sc.array(dims=["event"], values=birth_times, unit="ms") - - ess_beamline = fakes.FakeBeamline( + beamline = fakes.FakeBeamline( choppers=dream_choppers_with_frame_overlap(), monitors=monitors, - run_length=sc.scalar(1 / 14, unit="s") * npulses, - events_per_pulse=len(wavelengths), - source=partial( - tof_pkg.Source.from_neutrons, - birth_times=birth_times, - wavelengths=wavelengths, - frequency=sc.scalar(14.0, unit="Hz"), - ), + run_length=sc.scalar(1 / 14, unit="s") * 4, + events_per_pulse=10_000, + seed=88, ) - # Save the true wavelengths for later - true_wavelengths = ess_beamline.source.data.coords["wavelength"] - - raw_data = sc.concat( - [ess_beamline.get_monitor(key)[0] for key in monitors.keys()], + raw = sc.concat( + [beamline.get_monitor(key)[0].squeeze() for key in monitors.keys()], dim="detector", ).fold(dim="detector", sizes=ltotal.sizes) # Convert the time offset to the unit requested by the test - raw_data.bins.coords["event_time_offset"] = raw_data.bins.coords[ - "event_time_offset" - ].to(unit=time_offset_unit, copy=False) - - raw_data.coords["Ltotal"] = ltotal.to(unit=distance_unit, copy=False) - - # Verify that all 6 neutrons made it through the chopper cascade - assert sc.identical( - raw_data.bins.concat("pulse").hist().data, - sc.array( - dims=["detector"], - values=[len(wavelengths) * npulses] * len(monitors), - unit="counts", - dtype="float64", - ).fold(dim="detector", sizes=ltotal.sizes), + raw.bins.coords["event_time_offset"] = raw.bins.coords["event_time_offset"].to( + unit=time_offset_unit, copy=False ) + # Convert the distance to the unit requested by the test + raw.coords["Ltotal"] = raw.coords["Ltotal"].to(unit=distance_unit, copy=False) + + # Save reference data + ref = beamline.get_monitor(next(iter(monitors)))[1].squeeze() + ref = sc.sort(ref, key='id') - # Set up the 
workflow - workflow = sl.Pipeline( + pl = sl.Pipeline( time_of_flight.providers(), params=time_of_flight.default_parameters() ) - workflow[time_of_flight.RawData] = raw_data - workflow[time_of_flight.SimulationResults] = simulation_dream_choppers_time_overlap - workflow[time_of_flight.LookupTableRelativeErrorThreshold] = 0.01 - workflow[time_of_flight.LtotalRange] = ltotal.min(), ltotal.max() - - # Make sure some values in the lookup table have been masked (turned to NaNs) - original_table = workflow.compute(time_of_flight.TimeOfFlightLookupTable) - masked_table = workflow.compute(time_of_flight.MaskedTimeOfFlightLookupTable) - assert sc.isnan(masked_table).data.sum() > sc.isnan(original_table).data.sum() + pl[time_of_flight.RawData] = raw + pl[time_of_flight.SimulationResults] = simulation_dream_choppers_time_overlap + pl[time_of_flight.LtotalRange] = ltotal.min(), ltotal.max() + pl[time_of_flight.LookupTableRelativeErrorThreshold] = 0.01 - # Compute time-of-flight - tofs = workflow.compute(time_of_flight.TofData) - assert {dim: tofs.sizes[dim] for dim in ltotal.sizes} == ltotal.sizes + tofs = pl.compute(time_of_flight.TofData) # Convert to wavelength - graph = {**beamline(scatter=False), **elastic("tof")} - wav_wfm = tofs.transform_coords("wavelength", graph=graph) - - # Compare the computed wavelengths to the true wavelengths - for i in range(npulses): - result = wav_wfm["pulse", i].flatten(to="detector") - for j in range(len(result)): - computed_wavelengths = result[j].values.coords["wavelength"] - assert sc.allclose( - computed_wavelengths[:-2], - true_wavelengths["pulse", i][:-2], - rtol=sc.scalar(1e-02), - ) - # The two neutrons in the overlap region should have NaN wavelengths - assert sc.isnan(computed_wavelengths[-2]) - assert sc.isnan(computed_wavelengths[-1]) + graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} + wavs = tofs.transform_coords("wavelength", graph=graph) + + for da in wavs.flatten(to='pixel'): + x = sc.sort(da.value, 
key='id') + sel = sc.isfinite(x.coords["wavelength"]) + y = ref.coords["wavelength"][sel] + diff = abs((x.coords["wavelength"][sel] - y) / y) + assert np.nanpercentile(diff.values, 100) < 0.02 + sum_wfm = da.hist(wavelength=100).data.sum() + sum_ref = ref.hist(wavelength=100).data.sum() + # Verify that we lost some neutrons that were in the overlapping region + assert sum_wfm < sum_ref + assert sum_wfm > sum_ref * 0.9 @pytest.fixture(scope="module") @@ -340,7 +270,6 @@ def simulation_v20_choppers(): ) -@pytest.mark.parametrize("npulses", [1, 2]) @pytest.mark.parametrize( "ltotal", [ @@ -354,80 +283,56 @@ def simulation_v20_choppers(): @pytest.mark.parametrize("time_offset_unit", ["s", "ms", "us", "ns"]) @pytest.mark.parametrize("distance_unit", ["m", "mm"]) def test_v20_compute_wavelengths_from_wfm( - simulation_v20_choppers, npulses, ltotal, time_offset_unit, distance_unit + simulation_v20_choppers, ltotal, time_offset_unit, distance_unit ): monitors = { f"detector{i}": ltot for i, ltot in enumerate(ltotal.flatten(to="detector")) } # Create some neutron events - wavelengths = sc.array( - dims=["event"], values=[2.75, 4.2, 5.4, 6.5, 7.6, 8.75], unit="angstrom" - ) - birth_times = sc.full(sizes=wavelengths.sizes, value=1.5, unit="ms") - ess_beamline = fakes.FakeBeamline( + beamline = fakes.FakeBeamline( choppers=fakes.wfm_choppers(), monitors=monitors, - run_length=sc.scalar(1 / 14, unit="s") * npulses, - events_per_pulse=len(wavelengths), - source=partial( - tof_pkg.Source.from_neutrons, - birth_times=birth_times, - wavelengths=wavelengths, - frequency=sc.scalar(14.0, unit="Hz"), - ), + run_length=sc.scalar(1 / 14, unit="s") * 4, + events_per_pulse=10_000, + seed=99, ) - # Save the true wavelengths for later - true_wavelengths = ess_beamline.source.data.coords["wavelength"] - - raw_data = sc.concat( - [ess_beamline.get_monitor(key)[0] for key in monitors.keys()], + raw = sc.concat( + [beamline.get_monitor(key)[0].squeeze() for key in monitors.keys()], 
dim="detector", ).fold(dim="detector", sizes=ltotal.sizes) # Convert the time offset to the unit requested by the test - raw_data.bins.coords["event_time_offset"] = raw_data.bins.coords[ - "event_time_offset" - ].to(unit=time_offset_unit, copy=False) - - raw_data.coords["Ltotal"] = ltotal.to(unit=distance_unit, copy=False) - - # Verify that all 6 neutrons made it through the chopper cascade - assert sc.identical( - raw_data.bins.concat("pulse").hist().data, - sc.array( - dims=["detector"], - values=[len(wavelengths) * npulses] * len(monitors), - unit="counts", - dtype="float64", - ).fold(dim="detector", sizes=ltotal.sizes), + raw.bins.coords["event_time_offset"] = raw.bins.coords["event_time_offset"].to( + unit=time_offset_unit, copy=False ) + # Convert the distance to the unit requested by the test + raw.coords["Ltotal"] = raw.coords["Ltotal"].to(unit=distance_unit, copy=False) + + # Save reference data + ref = beamline.get_monitor(next(iter(monitors)))[1].squeeze() + ref = sc.sort(ref, key='id') - # Set up the workflow - workflow = sl.Pipeline( + pl = sl.Pipeline( time_of_flight.providers(), params=time_of_flight.default_parameters() ) - workflow[time_of_flight.RawData] = raw_data - workflow[time_of_flight.SimulationResults] = simulation_v20_choppers - workflow[time_of_flight.LtotalRange] = ltotal.min(), ltotal.max() + pl[time_of_flight.RawData] = raw + pl[time_of_flight.SimulationResults] = simulation_v20_choppers + pl[time_of_flight.LtotalRange] = ltotal.min(), ltotal.max() - # Compute time-of-flight - tofs = workflow.compute(time_of_flight.TofData) - assert {dim: tofs.sizes[dim] for dim in ltotal.sizes} == ltotal.sizes + tofs = pl.compute(time_of_flight.TofData) # Convert to wavelength - graph = {**beamline(scatter=False), **elastic("tof")} - wav_wfm = tofs.transform_coords("wavelength", graph=graph) - - # Compare the computed wavelengths to the true wavelengths - for i in range(npulses): - result = wav_wfm["pulse", i].flatten(to="detector") - for j in 
range(len(result)): - computed_wavelengths = result[j].values.coords["wavelength"] - assert sc.allclose( - computed_wavelengths, - true_wavelengths["pulse", i], - rtol=sc.scalar(1e-02), - ) + graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} + wavs = tofs.transform_coords("wavelength", graph=graph) + + for da in wavs.flatten(to='pixel'): + x = sc.sort(da.value, key='id') + diff = abs( + (x.coords["wavelength"] - ref.coords["wavelength"]) + / ref.coords["wavelength"] + ) + assert np.nanpercentile(diff.values, 99) < 0.02 + assert sc.isclose(ref.data.sum(), da.data.sum(), rtol=sc.scalar(1.0e-3))