Skip to content

Commit

Permalink
Fix tests for hdf_writer
Browse files Browse the repository at this point in the history
  • Loading branch information
jsouter committed Dec 8, 2023
1 parent e2cf4f9 commit b4b5456
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 26 deletions.
20 changes: 14 additions & 6 deletions src/ophyd_async/panda/writers/hdf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def _del_ctxt():

async def connect(self, sim=False) -> None:
pvi_info = await pvi_get(self._prefix + ":PVI", self.ctxt) if not sim else {}

# signals to connect, giving block name, signal name and datatype
desired_signals = {}
for block_name, block in self._to_capture.items():
if block_name not in desired_signals:
Expand All @@ -82,11 +84,13 @@ async def connect(self, sim=False) -> None:
desired_signals[block_name].append(
[f"{signal_name}_capture", SimpleCapture]
)
# add signals from DataBlock using type hints
if "hdf5" not in desired_signals:
desired_signals["hdf5"] = []
for signal_name, hint in get_type_hints(self.hdf5).items():
dtype = hint.__args__[0]
desired_signals["hdf5"].append([signal_name, dtype])
# loop over desired signals and set
for block_name, block_signals in desired_signals.items():
if block_name not in pvi_info:
continue
Expand All @@ -104,13 +108,17 @@ async def connect(self, sim=False) -> None:
read_pv = write_pv if len(pvs) == 1 else pvs[1]
pv_ctxt = self.ctxt.get(read_pv)
if dtype is SimpleCapture: # capture record
# some :CAPTURE PVs have only 2 values, many have 9
if set(pv_ctxt.value.choices) == set(v.value for v in Capture):
dtype = Capture
signal = self.pvi_mapping[operations](
dtype, "pva://" + read_pv, "pva://" + write_pv
)
setattr(block, signal_name, signal)
await block.connect()
for block_name in desired_signals.keys():
block = getattr(self, block_name)
if block:
await block.connect(sim=sim)

def __init__(
self,
Expand Down Expand Up @@ -206,11 +214,11 @@ async def collect_stream_docs(self, indices_written: int) -> AsyncIterator[Asset
for doc in self._file.stream_resources():
ds_name = doc["resource_kwargs"]["name"]
ds_block = doc["resource_kwargs"]["block"]
block = getattr(self, ds_block)
# capturing = getattr(self.hdf5, "capturing_" + ds_name, None)
capturing = getattr(block, f"{ds_name}_capture")
if capturing and await capturing.get_value() != Capture.No:
yield "stream_resource", doc
block = getattr(self, ds_block, None)
if block is not None:
capturing = getattr(block, f"{ds_name}_capture")
if capturing and await capturing.get_value() != Capture.No:
yield "stream_resource", doc
for doc in self._file.stream_data(indices_written):
yield "stream_datum", doc

Expand Down
63 changes: 43 additions & 20 deletions tests/panda/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,49 @@

from ophyd_async.core import (
DeviceCollector,
SimSignalBackend,
StaticDirectoryProvider,
set_and_wait_for_value,
)
from ophyd_async.epics.signal.signal import SignalR, epics_signal_rw
from ophyd_async.panda.writers.hdf_writer import PandaHDFWriter
from ophyd_async.panda.writers.panda_hdf import PandaHDF
from ophyd_async.epics.signal.signal import SignalR, SignalRW, epics_signal_rw
from ophyd_async.panda.writers import PandaHDFWriter


@pytest.fixture
async def sim_writer(tmp_path) -> PandaHDFWriter:
dir_prov = StaticDirectoryProvider(str(tmp_path), "test")
async with DeviceCollector(sim=True):
hdf = PandaHDF("TEST-PANDA")
writer = PandaHDFWriter(hdf, dir_prov, lambda: "test-panda")
writer = PandaHDFWriter("TEST-PANDA", dir_prov, lambda: "test-panda")
writer.hdf5.filepath = SignalRW(
SimSignalBackend(str, "TEST-PANDA:HDF5:FilePath")
)
writer.hdf5.filename = SignalRW(
SimSignalBackend(str, "TEST-PANDA:HDF5:FileName")
)
writer.hdf5.fullfilename = SignalRW(
SimSignalBackend(str, "TEST-PANDA:HDF5:FullFileName")
)
writer.hdf5.numcapture = SignalRW(
SimSignalBackend(int, "TEST-PANDA:HDF5:NumCapture")
)
writer.hdf5.capture = SignalRW(SimSignalBackend(int, "TEST-PANDA:HDF5:Capture"))
writer.hdf5.capturing = SignalRW(
SimSignalBackend(int, "TEST-PANDA:HDF5:Capturing")
)
writer.hdf5.flushnow = SignalRW(
SimSignalBackend(int, "TEST-PANDA:HDF5:FlushNow")
)
writer.hdf5.numwritten_rbv = SignalRW(
SimSignalBackend(int, "TEST-PANDA:HDF5:NumWritten_RBV")
)
await writer.connect(sim=True)
return writer


@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
async def test_open_returns_descriptors(sim_writer):
description = await sim_writer.open()
await sim_writer.hdf5.capturing.set(1)
description = await sim_writer.open() # to make capturing status not time out
assert isinstance(description, dict)
for key, entry in description.items():
assert isinstance(key, str)
Expand All @@ -36,23 +59,26 @@ async def test_open_returns_descriptors(sim_writer):

@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
async def test_open_close_sets_capture(sim_writer):
return_val = await sim_writer.open()
await sim_writer.hdf5.capturing.set(1)
return_val = await sim_writer.open() # to make capturing status not time out
assert isinstance(return_val, dict)
capturing = await sim_writer.hdf.capture.get_value()
assert capturing is True
capture = await sim_writer.hdf5.capture.get_value()
assert capture is True
await sim_writer.hdf5.capturing.set(0)
await sim_writer.close()
capturing = await sim_writer.hdf.capture.get_value()
assert capturing is False
capture = await sim_writer.hdf5.capture.get_value()
assert capture is False


@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
async def test_open_sets_file_path(sim_writer, tmp_path):
path = await sim_writer.hdf.file_path.get_value()
path = await sim_writer.hdf5.filepath.get_value()
assert path == ""
await sim_writer.hdf5.capturing.set(1) # to make capturing status not time out
await sim_writer.open()
path = await sim_writer.hdf.file_path.get_value()
path = await sim_writer.hdf5.filepath.get_value()
assert path == str(tmp_path)
name = await sim_writer.hdf.file_name.get_value()
name = await sim_writer.hdf5.filename.get_value()
assert name == "test.h5"


Expand All @@ -71,13 +97,10 @@ async def get_twentyfive():

@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
async def test_wait_for_index(sim_writer):
assert type(sim_writer.hdf.num_written) is SignalR
# usually num_written is a SignalR so can't be set from ophyd,
# usually numwritten_rbv is a SignalR so can't be set from ophyd,
# overload with SignalRW for testing
sim_writer.hdf.num_written = epics_signal_rw(int, "TEST-PANDA:HDF5:NumWritten")
await sim_writer.hdf.num_written.connect(sim=True)
await set_and_wait_for_value(sim_writer.hdf.num_written, 25)
assert (await sim_writer.hdf.num_written.get_value()) == 25
await set_and_wait_for_value(sim_writer.hdf5.numwritten_rbv, 25)
assert (await sim_writer.hdf5.numwritten_rbv.get_value()) == 25
await sim_writer.wait_for_index(25, timeout=1)
with pytest.raises(TimeoutError):
await sim_writer.wait_for_index(27, timeout=1)
Expand Down

0 comments on commit b4b5456

Please sign in to comment.