allowed a dictionary of alternate dataset names to be given by the writer
evalott100 committed Jun 5, 2024
1 parent c2c8c0b commit 34f9b20
Showing 2 changed files with 8 additions and 6 deletions.
12 changes: 7 additions & 5 deletions src/pandablocks/hdf.py
@@ -80,11 +80,12 @@ class HDFWriter(Pipeline):
     per file created.
     """

-    def __init__(self, file_names: Iterator[str]):
+    def __init__(self, file_names: Iterator[str], dataset_names: Dict[str, str]):
         super().__init__()
         self.file_names = file_names
         self.hdf_file: Optional[h5py.File] = None
         self.datasets: List[h5py.Dataset] = []
+        self.dataset_names = dataset_names
         self.what_to_do = {
             StartData: self.open_file,
             list: self.write_frame,
@@ -101,7 +102,7 @@ def create_dataset(self, field: FieldCapture, raw: bool):
             # No processor, datatype passed through
             dtype = field.type
         return self.hdf_file.create_dataset(
-            f"/{field.name}.{field.capture}",
+            f"/{self.dataset_names.get(field.name, field.name)}.{field.capture}",
             dtype=dtype,
             shape=(0,),
             maxshape=(None,),
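The renamed dataset path above falls back to the raw field name whenever no alternate name is supplied. A minimal sketch of that lookup (the field names and the "Value" capture suffix here are invented for illustration):

    dataset_names = {"COUNTER1.OUT": "counter"}

    # Mapped field: dataset is created as "/counter.Value"
    # instead of "/COUNTER1.OUT.Value"
    dataset_names.get("COUNTER1.OUT", "COUNTER1.OUT")  # -> "counter"

    # Unmapped field: keeps its original name
    dataset_names.get("INENC1.VAL", "INENC1.VAL")  # -> "INENC1.VAL"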
@@ -195,20 +196,21 @@ def scale_data(self, data: FrameData) -> List[np.ndarray]:


 def create_default_pipeline(
-    file_names: Iterator[str], *additional_downstream_pipelines: Pipeline
+    file_names: Iterator[str], dataset_names: Dict[str, str], *additional_downstream_pipelines: Pipeline
 ) -> List[Pipeline]:
     """Create the default processing pipeline consisting of one `FrameProcessor` and
     one `HDFWriter`. See `create_pipeline` for more details.
     Args:
         file_names: Iterator of file names. Must be full file paths. Will be called once
             per file created. As required by `HDFWriter`.
+        dataset_names: A dictionary of dataset names to use for each field.
         additional_downstream_pipelines: Any number of additional pipelines to add
             downstream.
     """

     return create_pipeline(
-        FrameProcessor(), HDFWriter(file_names), *additional_downstream_pipelines
+        FrameProcessor(), HDFWriter(file_names, dataset_names), *additional_downstream_pipelines
     )
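With the new signature, callers supply the mapping alongside the file name iterator. A hedged usage sketch (the file path and field name here are invented for illustration):

    from pandablocks.hdf import create_default_pipeline

    # Map PandA field names to alternate dataset names; any field not in
    # the dictionary keeps its original name.
    pipeline = create_default_pipeline(
        iter(["/tmp/capture_1.h5"]),
        {"COUNTER1.OUT": "counter"},
    )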


@@ -256,7 +258,7 @@ async def write_hdf_files(
     """
     counter = 0
     end_data = None
-    pipeline = create_default_pipeline(file_names)
+    pipeline = create_default_pipeline(file_names, {})
     try:
         async for data in client.data(scaled=False, flush_period=flush_period):
             pipeline[0].queue.put_nowait(data)
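Note that `write_hdf_files` passes an empty dictionary here, so files written through this convenience wrapper keep the original field names for their datasets.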
2 changes: 1 addition & 1 deletion tests/test_hdf.py
@@ -21,7 +21,7 @@ def __init__(self):

     try:
         pipeline = create_default_pipeline(
-            iter([Path(tmp_path / "1.h5")]), file_counter
+            iter([Path(tmp_path / "1.h5")]), {}, file_counter
         )

         pipeline[0].queue.put_nowait(
