diff --git a/src/ophyd_async/panda/writers/hdf_writer.py b/src/ophyd_async/panda/writers/hdf_writer.py index cb62a71e95..fd7e56edb4 100644 --- a/src/ophyd_async/panda/writers/hdf_writer.py +++ b/src/ophyd_async/panda/writers/hdf_writer.py @@ -74,6 +74,8 @@ def _del_ctxt(): async def connect(self, sim=False) -> None: pvi_info = await pvi_get(self._prefix + ":PVI", self.ctxt) if not sim else {} + + # signals to connect, giving block name, signal name and datatype desired_signals = {} for block_name, block in self._to_capture.items(): if block_name not in desired_signals: @@ -82,11 +84,13 @@ async def connect(self, sim=False) -> None: desired_signals[block_name].append( [f"{signal_name}_capture", SimpleCapture] ) + # add signals from DataBlock using type hints if "hdf5" not in desired_signals: desired_signals["hdf5"] = [] for signal_name, hint in get_type_hints(self.hdf5).items(): dtype = hint.__args__[0] desired_signals["hdf5"].append([signal_name, dtype]) + # loop over desired signals and set for block_name, block_signals in desired_signals.items(): if block_name not in pvi_info: continue @@ -104,13 +108,17 @@ async def connect(self, sim=False) -> None: read_pv = write_pv if len(pvs) == 1 else pvs[1] pv_ctxt = self.ctxt.get(read_pv) if dtype is SimpleCapture: # capture record + # some :CAPTURE PVs have only 2 values, many have 9 if set(pv_ctxt.value.choices) == set(v.value for v in Capture): dtype = Capture signal = self.pvi_mapping[operations]( dtype, "pva://" + read_pv, "pva://" + write_pv ) setattr(block, signal_name, signal) - await block.connect() + for block_name in desired_signals.keys(): + block = getattr(self, block_name) + if block: + await block.connect(sim=sim) def __init__( self, @@ -206,11 +214,11 @@ async def collect_stream_docs(self, indices_written: int) -> AsyncIterator[Asset for doc in self._file.stream_resources(): ds_name = doc["resource_kwargs"]["name"] ds_block = doc["resource_kwargs"]["block"] - block = getattr(self, ds_block) - # capturing = getattr(self.hdf5, "capturing_" + ds_name, None) - capturing = getattr(block, f"{ds_name}_capture") - if capturing and await capturing.get_value() != Capture.No: - yield "stream_resource", doc + block = getattr(self, ds_block, None) + if block is not None: + capturing = getattr(block, f"{ds_name}_capture") + if capturing and await capturing.get_value() != Capture.No: + yield "stream_resource", doc for doc in self._file.stream_data(indices_written): yield "stream_datum", doc diff --git a/tests/panda/test_writer.py b/tests/panda/test_writer.py index fc1d8712d3..f2227eba08 100644 --- a/tests/panda/test_writer.py +++ b/tests/panda/test_writer.py @@ -5,26 +5,49 @@ from ophyd_async.core import ( DeviceCollector, + SimSignalBackend, StaticDirectoryProvider, set_and_wait_for_value, ) -from ophyd_async.epics.signal.signal import SignalR, epics_signal_rw -from ophyd_async.panda.writers.hdf_writer import PandaHDFWriter -from ophyd_async.panda.writers.panda_hdf import PandaHDF +from ophyd_async.epics.signal.signal import SignalR, SignalRW, epics_signal_rw +from ophyd_async.panda.writers import PandaHDFWriter @pytest.fixture async def sim_writer(tmp_path) -> PandaHDFWriter: dir_prov = StaticDirectoryProvider(str(tmp_path), "test") async with DeviceCollector(sim=True): - hdf = PandaHDF("TEST-PANDA") - writer = PandaHDFWriter(hdf, dir_prov, lambda: "test-panda") + writer = PandaHDFWriter("TEST-PANDA", dir_prov, lambda: "test-panda") + writer.hdf5.filepath = SignalRW( + SimSignalBackend(str, "TEST-PANDA:HDF5:FilePath") + ) + writer.hdf5.filename = SignalRW( + SimSignalBackend(str, "TEST-PANDA:HDF5:FileName") + ) + writer.hdf5.fullfilename = SignalRW( + SimSignalBackend(str, "TEST-PANDA:HDF5:FullFileName") + ) + writer.hdf5.numcapture = SignalRW( + SimSignalBackend(int, "TEST-PANDA:HDF5:NumCapture") + ) + writer.hdf5.capture = SignalRW(SimSignalBackend(int, "TEST-PANDA:HDF5:Capture")) + writer.hdf5.capturing = SignalRW( + SimSignalBackend(int, "TEST-PANDA:HDF5:Capturing") + ) + writer.hdf5.flushnow = SignalRW( + SimSignalBackend(int, "TEST-PANDA:HDF5:FlushNow") + ) + writer.hdf5.numwritten_rbv = SignalRW( + SimSignalBackend(int, "TEST-PANDA:HDF5:NumWritten_RBV") + ) + await writer.connect(sim=True) return writer @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") async def test_open_returns_descriptors(sim_writer): - description = await sim_writer.open() + await sim_writer.hdf5.capturing.set(1) + description = await sim_writer.open() # to make capturing status not time out assert isinstance(description, dict) for key, entry in description.items(): assert isinstance(key, str) @@ -36,23 +59,26 @@ async def test_open_returns_descriptors(sim_writer): @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") async def test_open_close_sets_capture(sim_writer): - return_val = await sim_writer.open() + await sim_writer.hdf5.capturing.set(1) + return_val = await sim_writer.open() # to make capturing status not time out assert isinstance(return_val, dict) - capturing = await sim_writer.hdf.capture.get_value() - assert capturing is True + capture = await sim_writer.hdf5.capture.get_value() + assert capture is True + await sim_writer.hdf5.capturing.set(0) await sim_writer.close() - capturing = await sim_writer.hdf.capture.get_value() - assert capturing is False + capture = await sim_writer.hdf5.capture.get_value() + assert capture is False @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") async def test_open_sets_file_path(sim_writer, tmp_path): - path = await sim_writer.hdf.file_path.get_value() + path = await sim_writer.hdf5.filepath.get_value() assert path == "" + await sim_writer.hdf5.capturing.set(1) # to make capturing status not time out await sim_writer.open() - path = await sim_writer.hdf.file_path.get_value() + path = await sim_writer.hdf5.filepath.get_value() assert path == str(tmp_path) - name = await sim_writer.hdf.file_name.get_value() + name = await sim_writer.hdf5.filename.get_value() assert name == "test.h5" @@ -71,13 +97,10 @@ async def get_twentyfive(): @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") async def test_wait_for_index(sim_writer): - assert type(sim_writer.hdf.num_written) is SignalR - # usually num_written is a SignalR so can't be set from ophyd, + # usually numwritten_rbv is a SignalR so can't be set from ophyd, # overload with SignalRW for testing - sim_writer.hdf.num_written = epics_signal_rw(int, "TEST-PANDA:HDF5:NumWritten") - await sim_writer.hdf.num_written.connect(sim=True) - await set_and_wait_for_value(sim_writer.hdf.num_written, 25) - assert (await sim_writer.hdf.num_written.get_value()) == 25 + await set_and_wait_for_value(sim_writer.hdf5.numwritten_rbv, 25) + assert (await sim_writer.hdf5.numwritten_rbv.get_value()) == 25 await sim_writer.wait_for_index(25, timeout=1) with pytest.raises(TimeoutError): await sim_writer.wait_for_index(27, timeout=1)