Skip to content

Commit

Permalink
Merge pull request #203 from scipp/load_dream_csv
Browse files Browse the repository at this point in the history
Add dream.load_geant4_csv
  • Loading branch information
jl-wynen committed Oct 3, 2023
2 parents 1ff1045 + d63d207 commit 7025e17
Show file tree
Hide file tree
Showing 8 changed files with 287 additions and 0 deletions.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- ipykernel==6.25.1
- ipywidgets==8.1.0
- nbsphinx=0.9.2
- pandas=2.0.3
- pandoc=3.1.3
- pip=23.2.1
- plopp=23.09.0
Expand Down
5 changes: 5 additions & 0 deletions src/ess/dream/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)

from . import data
from .io import load_geant4_csv

__all__ = ['data', 'load_geant4_csv']
32 changes: 32 additions & 0 deletions src/ess/dream/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
_version = '1'

__all__ = ['get_path']


def _make_pooch():
import pooch

return pooch.create(
path=pooch.os_cache('ess/dream'),
env='ESS_DREAM_DATA_DIR',
base_url='https://public.esss.dk/groups/scipp/ess/dream/{version}/',
version=_version,
registry={
'data_dream_with_sectors.csv.zip': 'md5:52ae6eb3705e5e54306a001bc0ae85d8',
},
)


_pooch = _make_pooch()


def get_path(name: str) -> str:
"""
Return the path to a data file bundled with scippneutron.
This function only works with example data and cannot handle
paths to custom files.
"""
return _pooch.fetch(name)
6 changes: 6 additions & 0 deletions src/ess/dream/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)

from .geant4 import load_geant4_csv

__all__ = ['load_geant4_csv']
117 changes: 117 additions & 0 deletions src/ess/dream/io/geant4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)

import os
from io import BytesIO, StringIO
from typing import Dict, Optional, Union

import numpy as np
import scipp as sc

MANTLE_DETECTOR_ID = sc.index(7)
HIGH_RES_DETECTOR_ID = sc.index(8)
ENDCAPS_DETECTOR_IDS = tuple(map(sc.index, (3, 4, 5, 6)))


def load_geant4_csv(
filename: Union[str, os.PathLike, StringIO, BytesIO]
) -> sc.DataGroup:
"""Load a GEANT4 CSV file for DREAM.
Parameters
----------
filename:
Path to the GEANT4 CSV file.
Returns
-------
:
A :class:`scipp.DataGroup` containing the loaded events.
"""
events = _load_raw_events(filename)
detectors = _split_detectors(events)
for det in detectors.values():
_adjust_coords(det)
detectors = _group(detectors)

return sc.DataGroup({'instrument': sc.DataGroup(detectors)})


def _load_raw_events(
filename: Union[str, os.PathLike, StringIO, BytesIO]
) -> sc.DataArray:
table = sc.io.load_csv(filename, sep='\t', header_parser='bracket', data_columns=[])
table = table.rename_dims(row='event')
return sc.DataArray(
sc.ones(sizes=table.sizes, with_variances=True, unit='counts'),
coords=table.coords,
)


def _adjust_coords(da: sc.DataArray) -> None:
da.coords['wavelength'] = da.coords.pop('lambda')
da.coords['position'] = sc.spatial.as_vectors(
da.coords.pop('x_pos'), da.coords.pop('y_pos'), da.coords.pop('z_pos')
)


def _group(detectors: Dict[str, sc.DataArray]) -> Dict[str, sc.DataArray]:
elements = ('module', 'segment', 'counter', 'wire', 'strip')

def group(key: str, da: sc.DataArray) -> sc.DataArray:
if key == 'high_resolution':
# Only the HR detector has sectors.
return da.group('sector', *elements)
res = da.group(*elements)
res.bins.coords.pop('sector', None)
return res

return {key: group(key, da) for key, da in detectors.items()}


def _split_detectors(
data: sc.DataArray, detector_id_name: str = 'det ID'
) -> Dict[str, sc.DataArray]:
groups = data.group(
sc.concat(
[MANTLE_DETECTOR_ID, HIGH_RES_DETECTOR_ID, *ENDCAPS_DETECTOR_IDS],
dim=detector_id_name,
)
)
detectors = {}
if (
mantle := _extract_detector(groups, detector_id_name, MANTLE_DETECTOR_ID)
) is not None:
detectors['mantle'] = mantle.copy()
if (
high_res := _extract_detector(groups, detector_id_name, HIGH_RES_DETECTOR_ID)
) is not None:
detectors['high_resolution'] = high_res.copy()

endcaps_list = [
det
for i in ENDCAPS_DETECTOR_IDS
if (det := _extract_detector(groups, detector_id_name, i)) is not None
]
if endcaps_list:
endcaps = sc.concat(endcaps_list, data.dim)
endcaps = endcaps.bin(
z_pos=sc.array(
dims=['z_pos'],
values=[-np.inf, 0.0, np.inf],
unit=endcaps.coords['z_pos'].unit,
)
)
detectors['endcap_backward'] = endcaps[0].bins.concat().value.copy()
detectors['endcap_forward'] = endcaps[1].bins.concat().value.copy()

return detectors


def _extract_detector(
detector_groups: sc.DataArray, detector_id_name: str, detector_id: sc.Variable
) -> Optional[sc.DataArray]:
try:
return detector_groups[detector_id_name, detector_id].value
except IndexError:
return None
Empty file added tests/dream/__init__.py
Empty file.
Empty file added tests/dream/io/__init__.py
Empty file.
126 changes: 126 additions & 0 deletions tests/dream/io/geant4_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)

import zipfile
from io import BytesIO
from typing import Optional, Set

import numpy as np
import pytest
import scipp as sc
import scipp.testing

from ess.dream import data, load_geant4_csv


@pytest.fixture(scope='module')
def file_path():
return data.get_path('data_dream_with_sectors.csv.zip')


# Load file into memory only once
@pytest.fixture(scope='module')
def load_file(file_path):
with zipfile.ZipFile(file_path, 'r') as archive:
return archive.read(archive.namelist()[0])


@pytest.fixture(scope='function')
def file(load_file):
return BytesIO(load_file)


def assert_index_coord(
coord: sc.Variable, *, values: Optional[Set[int]] = None
) -> None:
assert coord.ndim == 1
assert coord.unit is None
assert coord.dtype == 'int64'
if values is not None:
assert set(np.unique(coord.values)) == values


def test_load_geant4_csv_loads_expected_structure(file):
loaded = load_geant4_csv(file)
assert isinstance(loaded, sc.DataGroup)
assert loaded.keys() == {'instrument'}

instrument = loaded['instrument']
assert isinstance(instrument, sc.DataGroup)
assert instrument.keys() == {
'mantle',
'high_resolution',
'endcap_forward',
'endcap_backward',
}


@pytest.mark.parametrize(
'key', ('mantle', 'high_resolution', 'endcap_forward', 'endcap_backward')
)
def test_load_gean4_csv_set_weights_to_one(file, key):
detector = load_geant4_csv(file)['instrument'][key]
events = detector.bins.constituents['data'].data
sc.testing.assert_identical(
events, sc.ones(sizes=events.sizes, with_variances=True, unit='counts')
)


def test_load_geant4_csv_mantle_has_expected_coords(file):
# Only testing ranges that will not change in the future
mantle = load_geant4_csv(file)['instrument']['mantle']
assert_index_coord(mantle.coords['module'])
assert_index_coord(mantle.coords['segment'])
assert_index_coord(mantle.coords['counter'])
assert_index_coord(mantle.coords['wire'], values=set(range(1, 33)))
assert_index_coord(mantle.coords['strip'], values=set(range(1, 257)))
assert 'sector' not in mantle.coords

assert 'sector' not in mantle.bins.coords
assert 'tof' in mantle.bins.coords
assert 'wavelength' in mantle.bins.coords
assert 'position' in mantle.bins.coords


def test_load_geant4_csv_endcap_backward_has_expected_coords(file):
endcap = load_geant4_csv(file)['instrument']['endcap_backward']
assert_index_coord(endcap.coords['module'])
assert_index_coord(endcap.coords['segment'])
assert_index_coord(endcap.coords['counter'])
assert_index_coord(endcap.coords['wire'], values=set(range(1, 17)))
assert_index_coord(endcap.coords['strip'], values=set(range(1, 17)))
assert 'sector' not in endcap.coords

assert 'sector' not in endcap.bins.coords
assert 'tof' in endcap.bins.coords
assert 'wavelength' in endcap.bins.coords
assert 'position' in endcap.bins.coords


def test_load_geant4_csv_endcap_forward_has_expected_coords(file):
endcap = load_geant4_csv(file)['instrument']['endcap_forward']
assert_index_coord(endcap.coords['module'])
assert_index_coord(endcap.coords['segment'])
assert_index_coord(endcap.coords['counter'])
assert_index_coord(endcap.coords['wire'], values=set(range(1, 17)))
assert_index_coord(endcap.coords['strip'], values=set(range(1, 17)))
assert 'sector' not in endcap.coords

assert 'sector' not in endcap.bins.coords
assert 'tof' in endcap.bins.coords
assert 'wavelength' in endcap.bins.coords
assert 'position' in endcap.bins.coords


def test_load_geant4_csv_high_resolution_has_expected_coords(file):
hr = load_geant4_csv(file)['instrument']['high_resolution']
assert_index_coord(hr.coords['module'])
assert_index_coord(hr.coords['segment'])
assert_index_coord(hr.coords['counter'])
assert_index_coord(hr.coords['wire'], values=set(range(1, 17)))
assert_index_coord(hr.coords['strip'], values=set(range(1, 33)))
assert_index_coord(hr.coords['sector'], values=set(range(1, 5)))

assert 'tof' in hr.bins.coords
assert 'wavelength' in hr.bins.coords
assert 'position' in hr.bins.coords

0 comments on commit 7025e17

Please sign in to comment.