From 34f9b202d8ead26cf29f454f5fda8b2431e3baa4 Mon Sep 17 00:00:00 2001
From: Eva
Date: Wed, 5 Jun 2024 16:32:41 +0100
Subject: [PATCH] allowed a dictionary of alternate dataset names to be given
 to the writer

---
 src/pandablocks/hdf.py | 12 +++++++-----
 tests/test_hdf.py      |  2 +-
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/pandablocks/hdf.py b/src/pandablocks/hdf.py
index 00d47b73..1ab31270 100644
--- a/src/pandablocks/hdf.py
+++ b/src/pandablocks/hdf.py
@@ -80,11 +80,12 @@ class HDFWriter(Pipeline):
             per file created.
     """
 
-    def __init__(self, file_names: Iterator[str]):
+    def __init__(self, file_names: Iterator[str], dataset_names: Dict[str, str]):
         super().__init__()
         self.file_names = file_names
         self.hdf_file: Optional[h5py.File] = None
         self.datasets: List[h5py.Dataset] = []
+        self.dataset_names = dataset_names
         self.what_to_do = {
             StartData: self.open_file,
             list: self.write_frame,
@@ -101,7 +102,7 @@ def create_dataset(self, field: FieldCapture, raw: bool):
             # No processor, datatype passed through
             dtype = field.type
         return self.hdf_file.create_dataset(
-            f"/{field.name}.{field.capture}",
+            f"/{self.dataset_names.get(field.name, field.name)}.{field.capture}",
             dtype=dtype,
             shape=(0,),
             maxshape=(None,),
@@ -195,7 +196,7 @@ def scale_data(self, data: FrameData) -> List[np.ndarray]:
 
 
 def create_default_pipeline(
-    file_names: Iterator[str], *additional_downstream_pipelines: Pipeline
+    file_names: Iterator[str], dataset_names: Dict[str, str], *additional_downstream_pipelines: Pipeline
 ) -> List[Pipeline]:
     """Create the default processing pipeline consisting of one `FrameProcessor` and
     one `HDFWriter`. See `create_pipeline` for more details.
@@ -203,12 +204,13 @@ def create_default_pipeline(
     Args:
         file_names: Iterator of file names. Must be full file paths. Will be called
             once per file created. As required by `HDFWriter`.
+        dataset_names: A dictionary of dataset names to use for each field.
         additional_downstream_pipelines: Any number of additional pipelines to add
             downstream.
     """
 
     return create_pipeline(
-        FrameProcessor(), HDFWriter(file_names), *additional_downstream_pipelines
+        FrameProcessor(), HDFWriter(file_names, dataset_names), *additional_downstream_pipelines
     )
 
 
@@ -256,7 +258,7 @@ async def write_hdf_files(
     """
     counter = 0
     end_data = None
-    pipeline = create_default_pipeline(file_names)
+    pipeline = create_default_pipeline(file_names, {})
     try:
         async for data in client.data(scaled=False, flush_period=flush_period):
             pipeline[0].queue.put_nowait(data)
diff --git a/tests/test_hdf.py b/tests/test_hdf.py
index 362e8da4..5ff7a2af 100644
--- a/tests/test_hdf.py
+++ b/tests/test_hdf.py
@@ -21,7 +21,7 @@ def __init__(self):
 
     try:
         pipeline = create_default_pipeline(
-            iter([Path(tmp_path / "1.h5")]), file_counter
+            iter([Path(tmp_path / "1.h5")]), {}, file_counter
        )
        pipeline[0].queue.put_nowait(
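
A minimal usage sketch of the new dataset_names parameter, assuming the names
defined in src/pandablocks/hdf.py above; the field name "COUNTER1.OUT" and the
output file path are illustrative assumptions only. Fields missing from the
mapping keep their own name, since create_dataset falls back via
dataset_names.get(field.name, field.name).

    from pandablocks.hdf import create_default_pipeline

    # Hypothetical mapping: write the captured field "COUNTER1.OUT" into a
    # dataset named "/counter1.<capture>" instead of "/COUNTER1.OUT.<capture>".
    dataset_names = {"COUNTER1.OUT": "counter1"}

    pipeline = create_default_pipeline(
        iter(["/tmp/capture_1.h5"]),  # one full file path per file created
        dataset_names,
    )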