diff --git a/src/probeinterface/io.py b/src/probeinterface/io.py index 78f283b3..b663584c 100644 --- a/src/probeinterface/io.py +++ b/src/probeinterface/io.py @@ -10,7 +10,7 @@ """ from pathlib import Path -from typing import Union, Optional +from typing import Union, Optional, List import re import warnings import json @@ -1680,10 +1680,83 @@ def read_mearec(file: Union[str, Path]) -> Probe: return probe -def read_nwb(file): +def read_nwb(nwbfile: Union[Path, str]) -> List[Probe]: """ - Read probe position from an NWB file + Load ndx_probeinterface.Probe as + probeinterface.Probe objects from an NWB file + Parameters + ---------- + nwbfile : Path or str + The path to nwbfile + + Returns + ------- + probe : List[Probe] + List of Probe objects + """ + try: + import ndx_probeinterface + import pynwb + except ImportError: + raise ImportError("Missing `ndx_probeinterface` or `pynwb`") + + with pynwb.NWBHDF5IO(nwbfile, mode="r", load_namespaces=True) as io: + nwbf = io.read() + ndx_probes = [] + for device in nwbf.devices: + if isinstance(device, ndx_probeinterface.Probe): + ndx_probes.append(device) + if not ndx_probes: + core_probe = _from_nwb_ElectrodeTable(nwbf.electrodes[:]) + return [core_probe] + probes = [] + for ndx_probe in ndx_probes: + probes.append(ndx_probeinterface.to_probeinterface(ndx_probe)) + + return probes + + +def _from_nwb_ElectrodeTable(nwbf_electrodes): """ + Load NWB core ElectrodeTable as probeinterface.Probe object + Warning: makes some assumptions - raise NotImplementedError + Parameters + ---------- + nwbfile : Path or str + The path to nwbfile + + Returns + ------- + probeinterface_probe : Probe + """ + + # infer dimension by + # 1. checking for columns with names 'rel_x', 'rel_y', 'rel_z' + # 2. checking how many of these columns have elements that are not all zero + rel_present = [rel for rel in ["rel_x", "rel_y", "rel_z"] if rel in nwbf_electrodes.columns] + true_rel_present = [] + for rel in rel_present: + if (nwbf_electrodes[rel][:] == 0).all() == False: + true_rel_present.append(rel) + ndim = len(true_rel_present) + assert ndim >= 1, "Insufficient position information to generate Probe object." + + # no way to read units when only ElectrodeTable is present; just assume microns + unit = "um" + + # create Probe + probeinterface_probe = Probe(ndim=ndim, si_units=unit) + + # infer positions + n_contacts = len(nwbf_electrodes) + positions = np.zeros((n_contacts, ndim)) + for i in range(ndim): + positions[:, i] = nwbf_electrodes[true_rel_present[i]][:] + + # set contacts + probeinterface_probe.set_contacts( + positions=positions, + ) + return probeinterface_probe diff --git a/tests/test_io/test_nwb.py b/tests/test_io/test_nwb.py new file mode 100644 index 00000000..17fdb8c8 --- /dev/null +++ b/tests/test_io/test_nwb.py @@ -0,0 +1,15 @@ +from pathlib import Path +import numpy as np + +import pytest + +try: + import ndx_probeinterface + import pynwb +except ImportError: + raise ImportError("Missing `ndx_probeinterface` or `pynwb`") + +# data_path = Path(__file__).absolute().parent.parent / "data" / "nwb" + +def test_nwb(): + return NotImplemented