Skip to content

Commit

Permalink
Fix probeinterface converter shape keys
Browse files Browse the repository at this point in the history
  • Loading branch information
rly committed Jul 11, 2024
1 parent bcf792c commit efa76fe
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 54 deletions.
40 changes: 30 additions & 10 deletions src/pynwb/ndx_extracellular_channels/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,24 @@ def from_probeinterface(

def to_probeinterface(ndx_probe: ndx_extracellular_channels.Probe) -> probeinterface.Probe:
"""
Construct a probeinterface.Probe from a ndx_extracellular_channels.Probe
Construct a probeinterface.Probe from a ndx_extracellular_channels.Probe.
ndx_extracellular_channels.Probe.name -> probeinterface.Probe.name
ndx_extracellular_channels.Probe.identifier -> probeinterface.Probe.serial_number
ndx_extracellular_channels.Probe.probe_model.name -> probeinterface.Probe.model_name
ndx_extracellular_channels.Probe.probe_model.manufacturer -> probeinterface.Probe.manufacturer
ndx_extracellular_channels.Probe.probe_model.ndim -> probeinterface.Probe.ndim
ndx_extracellular_channels.Probe.probe_model.planar_contour_in_um -> probeinterface.Probe.probe_planar_contour
ndx_extracellular_channels.Probe.probe_model.contacts_table["relative_position_in_mm"] ->
probeinterface.Probe.contact_positions
ndx_extracellular_channels.Probe.probe_model.contacts_table["shape"] -> probeinterface.Probe.contact_shapes
ndx_extracellular_channels.Probe.probe_model.contacts_table["contact_id"] -> probeinterface.Probe.contact_ids
ndx_extracellular_channels.Probe.probe_model.contacts_table["device_channel"] ->
probeinterface.Probe.device_channel_indices
ndx_extracellular_channels.Probe.probe_model.contacts_table["shank_id"] -> probeinterface.Probe.shank_ids
ndx_extracellular_channels.Probe.probe_model.contacts_table["plane_axes"] -> probeinterface.Probe.contact_plane_axes
ndx_extracellular_channels.Probe.probe_model.contacts_table["radius_in_um"] -> probeinterface.Probe.contact_shapes["radius"]
Parameters
----------
Expand All @@ -89,12 +106,11 @@ def to_probeinterface(ndx_probe: ndx_extracellular_channels.Probe) -> probeinter
shapes = []

contact_ids = None
shape_params = None
shank_ids = None
plane_axes = None
device_channel_indices = None

possible_shape_keys = ["radius", "width", "height"]
possible_shape_keys = ["radius_in_um", "width_in_um", "height_in_um"]
contacts_table = ndx_probe.probe_model.contacts_table

positions.append(contacts_table["relative_position_in_mm"][:])
Expand All @@ -115,11 +131,6 @@ def to_probeinterface(ndx_probe: ndx_extracellular_channels.Probe) -> probeinter
if shank_ids is None:
shank_ids = []
shank_ids.append(contacts_table["shank_id"][:])
for possible_shape_key in possible_shape_keys:
if possible_shape_key in contacts_table.colnames:
if shape_params is None:
shape_params = []
shape_params.append([{possible_shape_key: val} for val in contacts_table[possible_shape_key][:]])

positions = [item for sublist in positions for item in sublist]
shapes = [item for sublist in shapes for item in sublist]
Expand All @@ -128,13 +139,22 @@ def to_probeinterface(ndx_probe: ndx_extracellular_channels.Probe) -> probeinter
contact_ids = [item for sublist in contact_ids for item in sublist]
if plane_axes is not None:
plane_axes = [item for sublist in plane_axes for item in sublist]
if shape_params is not None:
shape_params = [item for sublist in shape_params for item in sublist]
if shank_ids is not None:
shank_ids = [item for sublist in shank_ids for item in sublist]
if device_channel_indices is not None:
device_channel_indices = [item for sublist in device_channel_indices for item in sublist]

# if there are multiple shape keys, e.g., radius, width, and height
# we need to create a list of dicts, one for each contact
shape_params = [dict() for _ in range(len(contacts_table))]
for i in range(len(contacts_table)):
for possible_shape_key in possible_shape_keys:
if possible_shape_key in contacts_table.colnames:
new_key = possible_shape_key.replace("_in_um", "")
shape_params[i][new_key] = contacts_table[possible_shape_key][i]

print(shape_params)

probeinterface_probe = probeinterface.Probe(
ndim=ndx_probe.probe_model.ndim,
si_units="um",
Expand Down
241 changes: 197 additions & 44 deletions src/pynwb/tests/test_example_usage_probeinterface.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import ndx_extracellular_channels
import numpy as np
import numpy.testing as npt
import probeinterface
import pynwb
import uuid
Expand Down Expand Up @@ -31,7 +32,7 @@ def test_from_probeinterface():
polygon = [(-20.0, -30.0), (20.0, -110.0), (60.0, -30.0), (60.0, 190.0), (-20.0, 190.0)]
probe0.set_planar_contour(polygon)

probe1 = probeinterface.generate_dummy_probe(elec_shapes="circle")
probe1 = probeinterface.generate_dummy_probe(elec_shapes="circle") # no name set
probe1.serial_number = "1000"
probe1.model_name = "Dummy Neuropixels 1.0"
probe1.manufacturer = "IMEC"
Expand All @@ -45,6 +46,7 @@ def test_from_probeinterface():
probe2.move([500, -90])

probe3 = probeinterface.generate_dummy_probe(elec_shapes="circle")
probe3.name = "probe3"
probe3.serial_number = "1002"
probe3.model_name = "Dummy Neuropixels 3.0"
probe3.manufacturer = "IMEC"
Expand All @@ -61,7 +63,8 @@ def test_from_probeinterface():
ndx_probes.extend(model0)
model1 = ndx_extracellular_channels.from_probeinterface(probe1, name="probe1") # override name of probe
ndx_probes.extend(model1)
group_probes = ndx_extracellular_channels.from_probeinterface(probegroup, name=[None, "probe3"])
# override name of probe3
group_probes = ndx_extracellular_channels.from_probeinterface(probegroup, name=[None, "renamed_probe3"])
ndx_probes.extend(group_probes)

nwbfile = pynwb.NWBFile(
Expand All @@ -79,13 +82,13 @@ def test_from_probeinterface():
io.write(nwbfile)

# read the file and check the content
with pynwb.NWBHDF5IO("test_probeinterface.nwb", "r", load_namespaces=True) as io:
with pynwb.NWBHDF5IO("test_probeinterface.nwb", "r") as io:
nwbfile = io.read()
assert set(nwbfile.devices.keys()) == {
"probe0",
"probe1",
"probe2",
"probe3",
"renamed_probe3",
"a1x32-edge-5mm-20-177_H32",
"Dummy Neuropixels 1.0",
"Dummy Neuropixels 2.0",
Expand All @@ -96,7 +99,7 @@ def test_from_probeinterface():
assert isinstance(nwbfile.devices["probe0"], ndx_extracellular_channels.Probe)
assert isinstance(nwbfile.devices["probe1"], ndx_extracellular_channels.Probe)
assert isinstance(nwbfile.devices["probe2"], ndx_extracellular_channels.Probe)
assert isinstance(nwbfile.devices["probe3"], ndx_extracellular_channels.Probe)
assert isinstance(nwbfile.devices["renamed_probe3"], ndx_extracellular_channels.Probe)
assert isinstance(nwbfile.devices["a1x32-edge-5mm-20-177_H32"], ndx_extracellular_channels.ProbeModel)
assert isinstance(nwbfile.devices["Dummy Neuropixels 1.0"], ndx_extracellular_channels.ProbeModel)
assert isinstance(nwbfile.devices["Dummy Neuropixels 2.0"], ndx_extracellular_channels.ProbeModel)
Expand All @@ -107,64 +110,214 @@ def test_from_probeinterface():
assert nwbfile.devices["probe0"].probe_model.name == "a1x32-edge-5mm-20-177_H32"
assert nwbfile.devices["probe0"].probe_model.manufacturer == "Neuronexus"
assert nwbfile.devices["probe0"].probe_model.ndim == 2
assert np.all(nwbfile.devices["probe0"].probe_model.planar_contour_in_um == polygon)
assert np.allclose(nwbfile.devices["probe0"].probe_model.contacts_table.relative_position_in_mm, positions)
assert np.all(nwbfile.devices["probe0"].probe_model.contacts_table["shape"].data[:] == "circle")
assert np.all(nwbfile.devices["probe0"].probe_model.contacts_table["radius_in_um"].data[:] == 5.0)
npt.assert_array_equal(nwbfile.devices["probe0"].probe_model.planar_contour_in_um, polygon)
npt.assert_allclose(nwbfile.devices["probe0"].probe_model.contacts_table.relative_position_in_mm, positions)
npt.assert_array_equal(nwbfile.devices["probe0"].probe_model.contacts_table["shape"].data[:], "circle")
npt.assert_array_equal(nwbfile.devices["probe0"].probe_model.contacts_table["radius_in_um"].data[:], 5.0)

assert nwbfile.devices["probe1"].name == "probe1"
assert nwbfile.devices["probe1"].identifier == "1000"
assert nwbfile.devices["probe1"].probe_model.name == "Dummy Neuropixels 1.0"
assert nwbfile.devices["probe1"].probe_model.manufacturer == "IMEC"
assert nwbfile.devices["probe1"].probe_model.ndim == 2
assert np.allclose(nwbfile.devices["probe1"].probe_model.planar_contour_in_um, probe1.probe_planar_contour)
assert np.allclose(
npt.assert_allclose(nwbfile.devices["probe1"].probe_model.planar_contour_in_um, probe1.probe_planar_contour)
npt.assert_allclose(
nwbfile.devices["probe1"].probe_model.contacts_table.relative_position_in_mm, probe1.contact_positions
)
assert np.all(nwbfile.devices["probe1"].probe_model.contacts_table["shape"].data[:] == "circle")
assert np.all(
nwbfile.devices["probe1"].probe_model.contacts_table["radius_in_um"].data[:] == probe1.to_numpy()["radius"]
npt.assert_array_equal(nwbfile.devices["probe1"].probe_model.contacts_table["shape"].data[:], "circle")
npt.assert_array_equal(
nwbfile.devices["probe1"].probe_model.contacts_table["radius_in_um"].data[:], probe1.to_numpy()["radius"]
)

assert nwbfile.devices["probe2"].name == "probe2"
assert nwbfile.devices["probe2"].identifier == "1001"
assert nwbfile.devices["probe2"].probe_model.name == "Dummy Neuropixels 2.0"
assert nwbfile.devices["probe2"].probe_model.manufacturer == "IMEC"
assert nwbfile.devices["probe2"].probe_model.ndim == 2
assert np.allclose(nwbfile.devices["probe2"].probe_model.planar_contour_in_um, probe2.probe_planar_contour)
assert np.allclose(
npt.assert_allclose(nwbfile.devices["probe2"].probe_model.planar_contour_in_um, probe2.probe_planar_contour)
npt.assert_allclose(
nwbfile.devices["probe2"].probe_model.contacts_table.relative_position_in_mm, probe2.contact_positions
)
assert np.all(nwbfile.devices["probe2"].probe_model.contacts_table["shape"].data[:] == "square")
assert np.all(
nwbfile.devices["probe2"].probe_model.contacts_table["width_in_um"].data[:] == probe2.to_numpy()["width"]
npt.assert_array_equal(nwbfile.devices["probe2"].probe_model.contacts_table["shape"].data[:], "square")
npt.assert_array_equal(
nwbfile.devices["probe2"].probe_model.contacts_table["width_in_um"].data[:], probe2.to_numpy()["width"]
)

assert nwbfile.devices["probe3"].name == "probe3"
assert nwbfile.devices["probe3"].identifier == "1002"
assert nwbfile.devices["probe3"].probe_model.name == "Dummy Neuropixels 3.0"
assert nwbfile.devices["probe3"].probe_model.manufacturer == "IMEC"
assert nwbfile.devices["probe3"].probe_model.ndim == 2
assert np.allclose(nwbfile.devices["probe3"].probe_model.planar_contour_in_um, probe3.probe_planar_contour)
assert np.allclose(
nwbfile.devices["probe3"].probe_model.contacts_table.relative_position_in_mm, probe3.contact_positions
assert nwbfile.devices["renamed_probe3"].name == "renamed_probe3"
assert nwbfile.devices["renamed_probe3"].identifier == "1002"
assert nwbfile.devices["renamed_probe3"].probe_model.name == "Dummy Neuropixels 3.0"
assert nwbfile.devices["renamed_probe3"].probe_model.manufacturer == "IMEC"
assert nwbfile.devices["renamed_probe3"].probe_model.ndim == 2
npt.assert_allclose(
nwbfile.devices["renamed_probe3"].probe_model.planar_contour_in_um, probe3.probe_planar_contour
)
assert np.all(nwbfile.devices["probe3"].probe_model.contacts_table["shape"].data[:] == "circle")
assert np.all(
nwbfile.devices["probe3"].probe_model.contacts_table["radius_in_um"].data[:] == probe3.to_numpy()["radius"]
npt.assert_allclose(
nwbfile.devices["renamed_probe3"].probe_model.contacts_table.relative_position_in_mm,
probe3.contact_positions,
)
npt.assert_array_equal(nwbfile.devices["renamed_probe3"].probe_model.contacts_table["shape"].data[:], "circle")
npt.assert_array_equal(
nwbfile.devices["renamed_probe3"].probe_model.contacts_table["radius_in_um"].data[:],
probe3.to_numpy()["radius"]
)


def test_to_probeinterface():

# create a NWB file with a few probes
nwbfile = pynwb.NWBFile(
session_description="A description of my session",
identifier=str(uuid.uuid4()),
session_start_time=datetime.datetime.now(datetime.timezone.utc),
)

# create a probe model
probe_model0 = ndx_extracellular_channels.ProbeModel(
name="a1x32-edge-5mm-20-177_H32",
model="a1x32-edge-5mm-20-177_H32",
manufacturer="Neuronexus",
ndim=2,
planar_contour_in_um=[(-20.0, -30.0), (20.0, -110.0), (60.0, -30.0), (60.0, 190.0), (-20.0, 190.0)],
contacts_table=ndx_extracellular_channels.ContactsTable(
name="contacts_table",
description="a table with electrode contacts",
columns=[
pynwb.core.VectorData(
name="relative_position_in_mm",
description="the relative position of the contact in mm",
data=[
(0.0, 0.0),
(0.0, 20.0),
(0.0, 40.0),
(0.0, 60.0),
(0.0, 80.0),
(0.0, 100.0),
(0.0, 120.0),
(0.0, 140.0),
(20.0, 0.0),
(20.0, 20.0),
(20.0, 40.0),
(20.0, 60.0),
(20.0, 80.0),
(20.0, 100.0),
(20.0, 120.0),
(20.0, 140.0),
(40.0, 0.0),
(40.0, 20.0),
(40.0, 40.0),
(40.0, 60.0),
(40.0, 80.0),
(40.0, 100.0),
(40.0, 120.0),
(40.0, 140.0),
],
),
pynwb.core.VectorData(
name="shape",
description="the shape of the contact",
data=["circle"] * 24,
),
pynwb.core.VectorData(
name="radius_in_um",
description="the radius of the contact in um",
data=[5.0] * 24,
),
],
),
)

# create a probe
probe0 = ndx_extracellular_channels.Probe(
name="probe0",
identifier="0123",
probe_model=probe_model0,
)

# for device in nwbfile.devices.values():
# print("-------------------")
# print(device)
# if isinstance(device, ndx_extracellular_channels.ProbeModel):
# print(device.name)
# print(device.manufacturer)
# print(device.ndim)
# print(device.planar_contour_in_um)
# print(device.contacts_table.to_dataframe())
# if isinstance(device, ndx_extracellular_channels.Probe):
# pi_probe = ndx_extracellular_channels.to_probeinterface(device)
# print(pi_probe)

# TODO add more tests for other probeinterface IO functions
pi_probe0 = ndx_extracellular_channels.to_probeinterface(probe0)
assert pi_probe0.ndim == 2
assert pi_probe0.si_units == "um"
assert pi_probe0.name == "probe0"
assert pi_probe0.serial_number == "0123"
assert pi_probe0.model_name == "a1x32-edge-5mm-20-177_H32"
assert pi_probe0.manufacturer == "Neuronexus"
npt.assert_array_equal(pi_probe0.contact_positions, probe_model0.contacts_table.relative_position_in_mm)
npt.assert_array_equal(pi_probe0.contact_shapes, "circle")
npt.assert_array_equal(pi_probe0.to_numpy()["radius"], 5.0)

ct2 = ndx_extracellular_channels.ContactsTable(
description="Test contacts table",
)

# for testing, mix and match different shapes. np.nan means the radius/width/height does not apply
ct2.add_row(
relative_position_in_mm=[10.0, 10.0],
shape="circle",
contact_id="C1",
shank_id="shank0",
plane_axes=[[0.0, 1.0], [1.0, 0.0]], # TODO make realistic
radius_in_um=10.0,
width_in_um=np.nan,
height_in_um=np.nan,
device_channel=1,
)
ct2.add_row(
relative_position_in_mm=[20.0, 10.0],
shape="square",
contact_id="C2",
shank_id="shank0",
plane_axes=[[0.0, 1.0], [1.0, 0.0]], # TODO make realistic
radius_in_um=np.nan,
width_in_um=10.0,
height_in_um=10.0,
device_channel=2,
)
probe_model1 = ndx_extracellular_channels.ProbeModel(
name="Neuropixels 1.0",
description="A neuropixels probe",
model="Neuropixels 1.0",
manufacturer="IMEC",
planar_contour_in_um=[[-10.0, -10.0], [10.0, -10.0], [10.0, 10.0], [-10.0, 10.0]],
contacts_table=ct2,
)

# create a probe
probe1 = ndx_extracellular_channels.Probe(
name="probe1",
identifier="7890",
probe_model=probe_model1,
)

pi_probe1 = ndx_extracellular_channels.to_probeinterface(probe1)
assert pi_probe1.ndim == 2
assert pi_probe1.si_units == "um"
assert pi_probe1.name == "probe1"
assert pi_probe1.serial_number == "7890"
assert pi_probe1.model_name == "Neuropixels 1.0"
assert pi_probe1.manufacturer == "IMEC"
npt.assert_array_equal(pi_probe1.contact_positions, probe_model1.contacts_table.relative_position_in_mm)
npt.assert_array_equal(pi_probe1.contact_shapes, ["circle", "square"])
npt.assert_array_equal(pi_probe1.to_numpy()["radius"], [10.0, np.nan])
npt.assert_array_equal(pi_probe1.to_numpy()["width"], [np.nan, 10.0])
npt.assert_array_equal(pi_probe1.to_numpy()["height"], [np.nan, 10.0])

# add Probe as NWB Devices
nwbfile.add_device(probe_model0)
nwbfile.add_device(probe0)

with pynwb.NWBHDF5IO("test_probeinterface.nwb", "w") as io:
io.write(nwbfile)

# read the file and test whether the read probe can be converted back to probeinterface correctly
with pynwb.NWBHDF5IO("test_probeinterface.nwb", "r") as io:
nwbfile = io.read()
read_probe = nwbfile.devices["probe0"]
pi_probe = ndx_extracellular_channels.to_probeinterface(read_probe)
assert pi_probe.ndim == 2
assert pi_probe.si_units == "um"
assert pi_probe.name == "probe0"
assert pi_probe.serial_number == "0123"
assert pi_probe.model_name == "a1x32-edge-5mm-20-177_H32"
assert pi_probe.manufacturer == "Neuronexus"
npt.assert_array_equal(pi_probe.contact_positions, probe_model0.contacts_table.relative_position_in_mm)
npt.assert_array_equal(pi_probe.to_numpy()["radius"], 5.0)
npt.assert_array_equal(pi_probe.contact_shapes, "circle")

0 comments on commit efa76fe

Please sign in to comment.