diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 6ea4db9..0c7874a 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -47,6 +47,4 @@ jobs: - uses: pypa/gh-action-pypi-publish@release/v1 if: github.event_name == 'release' && github.event.action == 'published' with: - # Remember to tell (test-)pypi about this repo before publishing - # Remove this line to publish to PyPI - repository-url: https://test.pypi.org/legacy/ + packages-dir: ./dist/ diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index efcc345..6bdf544 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,8 +4,6 @@ on: workflow_dispatch: pull_request: push: - branches: - - main concurrency: group: ${{ github.workflow }}-${{ github.ref }} @@ -34,36 +32,78 @@ jobs: pipx run nox -s pylint checks: + # pull requests are a duplicate of a branch push if within the same repo. + if: + github.event_name != 'pull_request' || + github.event.pull_request.head.repo.full_name != github.repository + name: Check Python ${{ matrix.python-version }} on ${{ matrix.runs-on }} runs-on: ${{ matrix.runs-on }} needs: [pre-commit] strategy: fail-fast: false matrix: - python-version: ["3.8", "3.12"] - runs-on: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.8", "3.9", "3.10", "3.11"] + # runs-on: [ubuntu-latest, macos-latest, windows-latest] + runs-on: [ubuntu-latest] + + # include: + # - python-version: pypy-3.10 + # runs-on: ubuntu-latest + env: + TZ: America/New_York - include: - - python-version: pypy-3.10 - runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} steps: - uses: actions/checkout@v4 with: fetch-depth: 0 - - uses: actions/setup-python@v5 + - name: Set env vars + run: | + set -x + export REPOSITORY_NAME=${GITHUB_REPOSITORY#*/} # just the repo, as opposed to org/repo + echo "REPOSITORY_NAME=${REPOSITORY_NAME}" >> $GITHUB_ENV + + export PYTHONVER=$(echo ${{ matrix.python-version }} | sed 's/\.//g') + echo "PYTHONVER=${PYTHONVER}" >> $GITHUB_ENV + + export DATETIME_STRING=$(date +%Y%m%d%H%M%S) + echo "DATETIME_STRING=${DATETIME_STRING}" >> $GITHUB_ENV + + # - uses: actions/setup-python@v5 + # with: + # python-version: ${{ matrix.python-version }} + # allow-prereleases: true + + - name: Set up Python ${{ matrix.python-version }} with conda + uses: mamba-org/setup-micromamba@v1 with: - python-version: ${{ matrix.python-version }} - allow-prereleases: true + init-shell: bash + environment-name: ${{env.REPOSITORY_NAME}}-py${{matrix.python-version}} + create-args: >- + python=${{ matrix.python-version }} epics-base setuptools<67 - name: Install package - run: python -m pip install .[test] + run: | + set -vxeuo pipefail + which caput + python -m pip install .[test] - name: Test package run: >- python -m pytest -ra --cov --cov-report=xml --cov-report=term - --durations=20 + --durations=20 -m "(not hardware) and (not tiled)" -s -vv + + - name: Upload test artifacts + uses: actions/upload-artifact@v4 + with: + name: ${{env.REPOSITORY_NAME}}-py${{env.PYTHONVER}}-${{env.DATETIME_STRING}} + path: /tmp/srx-caproto-iocs/ + retention-days: 14 - name: Upload coverage report uses: codecov/codecov-action@v4.0.1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ac556dd..199f43e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,27 +33,27 @@ repos: - id: rst-inline-touching-normal - repo: https://github.com/pre-commit/mirrors-prettier - rev: "v3.1.0" + rev: "v4.0.0-alpha.8" hooks: - id: prettier types_or: [yaml, markdown, html, css, scss, javascript, json] args: [--prose-wrap=always] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.1.14" + rev: "v0.2.1" hooks: - id: ruff args: ["--fix", "--show-fixes"] - id: ruff-format - - repo: https://github.com/pre-commit/mirrors-mypy - rev: "v1.8.0" - hooks: - - id: mypy - files: src|tests - args: [] - additional_dependencies: - - pytest + # - repo: https://github.com/pre-commit/mirrors-mypy + # rev: "v1.8.0" + # hooks: + # - id: mypy + # files: src|tests + # args: [] + # additional_dependencies: + # - pytest - repo: https://github.com/codespell-project/codespell rev: "v2.2.6" @@ -80,7 +80,7 @@ repos: additional_dependencies: ["validate-pyproject-schema-store[all]"] - repo: https://github.com/python-jsonschema/check-jsonschema - rev: "0.27.3" + rev: "0.28.0" hooks: - id: check-dependabot - id: check-github-workflows diff --git a/pyproject.toml b/pyproject.toml index e02b752..580497f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["hatchling", "hatch-vcs"] +requires = ["hatchling", "hatch-vcs", "setuptools>=61,<67"] build-backend = "hatchling.build" @@ -30,7 +30,14 @@ classifiers = [ "Typing :: Typed", ] dynamic = ["version"] -dependencies = [] +dependencies = [ + "caproto", + "h5py", + "numpy", + "ophyd", + "pyepics", # does not work with 'setuptools' version higher than v66.1.1 + "scikit-image[data]", +] [project.optional-dependencies] test = [ @@ -38,6 +45,10 @@ test = [ "pytest-cov >=3", ] dev = [ + "ipython", + "nexpy", + "pre-commit", + "pylint", "pytest >=6", "pytest-cov >=3", ] @@ -76,7 +87,12 @@ log_cli_level = "INFO" testpaths = [ "tests", ] - +markers = [ + "hardware: marks tests as requiring the hardware IOC to be available/running (deselect with '-m \"not hardware\"')", + "tiled: marks tests as requiring tiled", + "cloud_friendly: marks tests to be able to execute in the CI in the cloud", + "needs_epics_core: marks tests as requiring epics-core executables such as caget, caput, etc." +] [tool.coverage] run.source = ["srx_caproto_iocs"] @@ -108,7 +124,7 @@ src = ["src"] extend-select = [ "B", # flake8-bugbear "I", # isort - "ARG", # flake8-unused-arguments + # "ARG", # flake8-unused-arguments "C4", # flake8-comprehensions "EM", # flake8-errmsg "ICN", # flake8-import-conventions @@ -119,9 +135,9 @@ extend-select = [ "PT", # flake8-pytest-style "PTH", # flake8-use-pathlib "RET", # flake8-return - "RUF", # Ruff-specific + # "RUF", # Ruff-specific "SIM", # flake8-simplify - "T20", # flake8-print + # "T20", # flake8-print "UP", # pyupgrade "YTT", # flake8-2020 "EXE", # flake8-executable diff --git a/scripts/run-act.sh b/scripts/run-act.sh new file mode 100644 index 0000000..074a8c6 --- /dev/null +++ b/scripts/run-act.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +set -vxeuo pipefail + +PYTHON_VERSION="${1:-3.11}" + +act -W .github/workflows/ci.yml -j checks --matrix python-version:"${PYTHON_VERSION}" diff --git a/scripts/run-caproto-ioc.sh b/scripts/run-caproto-ioc.sh new file mode 100644 index 0000000..808b8dd --- /dev/null +++ b/scripts/run-caproto-ioc.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +set -vxeuo pipefail + +CAPROTO_IOC="${1:-srx_caproto_iocs.base}" +DEFAULT_PREFIX="BASE:{{Dev:Save1}}:" +CAPROTO_IOC_PREFIX="${2:-${DEFAULT_PREFIX}}" +# shellcheck source=/dev/null +if [ -f "/etc/profile.d/epics.sh" ]; then + . /etc/profile.d/epics.sh +fi + +export EPICS_CAS_AUTO_BEACON_ADDR_LIST="no" +export EPICS_CAS_BEACON_ADDR_LIST="${EPICS_CA_ADDR_LIST:-127.0.0.255}" + +python -m "${CAPROTO_IOC}" --prefix="${CAPROTO_IOC_PREFIX}" --list-pvs diff --git a/scripts/run-caproto-zebra-ioc.sh b/scripts/run-caproto-zebra-ioc.sh new file mode 100644 index 0000000..1d4118c --- /dev/null +++ b/scripts/run-caproto-zebra-ioc.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +set -vxeuo pipefail + +SCRIPT_DIR="$(dirname "$0")" + +bash "${SCRIPT_DIR}/run-caproto-ioc.sh" srx_caproto_iocs.zebra.caproto_ioc "XF:05IDD-ES:1{{Dev:Zebra2}}:" diff --git a/scripts/test-file-saving.sh b/scripts/test-file-saving.sh new file mode 100644 index 0000000..2159796 --- /dev/null +++ b/scripts/test-file-saving.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +set -euo pipefail + +# shellcheck source=/dev/null +if [ -f "/etc/profile.d/epics.sh" ]; then + . /etc/profile.d/epics.sh +fi + +num="${1:-50}" + +data_dir="/tmp/test/$(date +%Y/%m/%d)" +mkdir -v -p "${data_dir}" + +caput "BASE:{Dev:Save1}:write_dir" "${data_dir}" +caput "BASE:{Dev:Save1}:file_name" "saveme_{num:06d}_{uid}.h5" +caput "BASE:{Dev:Save1}:stage" 1 +caget -S "BASE:{Dev:Save1}:full_file_path" +for i in $(seq "$num"); do + echo "$i" + sleep 0.1 + caput "BASE:{Dev:Save1}:acquire" 1 +done + +caput "BASE:{Dev:Save1}:stage" 0 + +caget -S "BASE:{Dev:Save1}:full_file_path" + +exit 0 diff --git a/src/srx_caproto_iocs/base.py b/src/srx_caproto_iocs/base.py new file mode 100644 index 0000000..412d33c --- /dev/null +++ b/src/srx_caproto_iocs/base.py @@ -0,0 +1,296 @@ +from __future__ import annotations + +import re +import textwrap +import threading +import time as ttime +import uuid +from enum import Enum +from pathlib import Path + +import skimage.data +from caproto import ChannelType +from caproto.ioc_examples.mini_beamline import no_reentry +from caproto.server import PVGroup, pvproperty, run, template_arg_parser +from ophyd import Component as Cpt +from ophyd import Device, EpicsSignal, EpicsSignalRO +from ophyd.status import SubscriptionStatus + +from .utils import now, save_hdf5_nd + + +class AcqStatuses(Enum): + """Enum class for acquisition statuses.""" + + IDLE = "idle" + ACQUIRING = "acquiring" + + +class StageStates(Enum): + """Enum class for stage states.""" + + UNSTAGED = "unstaged" + STAGED = "staged" + + +class CaprotoSaveIOC(PVGroup): + """Generic Caproto Save IOC""" + + write_dir = pvproperty( + value="/tmp", + doc="The directory to write data to. It support datetime formatting, e.g. '/tmp/det/%Y/%m/%d/'", + string_encoding="utf-8", + dtype=ChannelType.CHAR, + max_length=255, + ) + file_name = pvproperty( + value="test.h5", + doc="The file name of the file to write to. It support .format() based formatting, e.g. 'scan_{num:06d}.h5'", + string_encoding="utf-8", + dtype=ChannelType.CHAR, + max_length=255, + ) + full_file_path = pvproperty( + value="", + doc="Full path to the data file", + dtype=str, + read_only=True, + max_length=255, + ) + + # TODO: check non-negative value in @frame_num.putter. + frame_num = pvproperty(value=0, doc="Frame counter", dtype=int) + + stage = pvproperty( + value=StageStates.UNSTAGED.value, + enum_strings=[x.value for x in StageStates], + dtype=ChannelType.ENUM, + doc="Stage/unstage the device", + ) + + acquire = pvproperty( + value=AcqStatuses.IDLE.value, + enum_strings=[x.value for x in AcqStatuses], + dtype=ChannelType.ENUM, + doc="Acquire signal to save a dataset.", + ) + + def __init__(self, *args, update_rate=10.0, **kwargs): + super().__init__(*args, **kwargs) + + self._update_rate = update_rate + self._update_period = 1.0 / update_rate + + self._request_queue = None + self._response_queue = None + + queue = pvproperty(value=0, doc="A PV to facilitate threading-based queue") + + @queue.startup + async def queue(self, instance, async_lib): + """The startup behavior of the count property to set up threading queues.""" + # pylint: disable=unused-argument + self._request_queue = async_lib.ThreadsafeQueue(maxsize=1) + self._response_queue = async_lib.ThreadsafeQueue(maxsize=1) + + # Start a separate thread that consumes requests and sends responses. + thread = threading.Thread( + target=self.saver, + daemon=True, + kwargs={ + "request_queue": self._request_queue, + "response_queue": self._response_queue, + }, + ) + thread.start() + + async def _stage(self, instance, value): + """The stage method to perform preparation of a dataset to save the data.""" + if ( + instance.value in [True, StageStates.STAGED.value] + and value == StageStates.STAGED.value + ): + msg = "The device is already staged. Unstage it first." + print(msg) + return False + + if value == StageStates.STAGED.value: + # Steps: + # 1. Render 'write_dir' with datetime lib + # 2. Replace unsupported characters with underscores (sanitize). + # 3. Check if sanitized 'write_dir' exists + # 4. Render 'file_name' with .format(). + # 5. Replace unsupported characters with underscores. + + sanitizer = re.compile(pattern=r"[\":<>|\*\?\s]") + date = now(as_object=True) + write_dir = Path(sanitizer.sub("_", date.strftime(self.write_dir.value))) + if not write_dir.exists(): + msg = f"Path '{write_dir}' does not exist." + print(msg) + return False + + file_name = self.file_name.value + + uid = ( + str(uuid.uuid4()) if "{uid" in file_name or "{suid" in file_name else "" + ) + + full_file_path = write_dir / file_name.format( + num=self.frame_num.value, uid=uid, suid=uid[:8] + ) + full_file_path = sanitizer.sub("_", str(full_file_path)) + + print(f"{now()}: {full_file_path = }") + + await self.full_file_path.write(full_file_path) + + return True + + return False + + @stage.putter + async def stage(self, *args, **kwargs): + return await self._stage(*args, **kwargs) + + async def _get_current_dataset(self, frame): + """The method to return a desired dataset. + + See https://scikit-image.org/docs/stable/auto_examples/data/plot_3d.html + for details about the dataset returned by the base class' method. + """ + dataset = skimage.data.cells3d().sum(axis=1) + # This particular example dataset has 60 frames available, so we will cycle the slices for frame>=60. + return dataset[frame % dataset.shape[0], ...] + + @acquire.putter + @no_reentry + async def acquire(self, instance, value): + """The acquire method to perform an individual acquisition of a data point.""" + if ( + value != AcqStatuses.ACQUIRING.value + # or self.stage.value not in [True, StageStates.STAGED.value] + ): + return False + + if ( + instance.value in [True, AcqStatuses.ACQUIRING.value] + and value == AcqStatuses.ACQUIRING.value + ): + print( + f"The device is already acquiring. Please wait until the '{AcqStatuses.IDLE.value}' status." + ) + return True + + await self.acquire.write(AcqStatuses.ACQUIRING.value) + + # Delegate saving the resulting data to a blocking callback in a thread. + payload = { + "filename": self.full_file_path.value, + "data": await self._get_current_dataset(frame=self.frame_num.value), + "uid": str(uuid.uuid4()), + "timestamp": ttime.time(), + "frame_number": self.frame_num.value, + } + + await self._request_queue.async_put(payload) + response = await self._response_queue.async_get() + + if response["success"]: + # Increment the counter only on a successful saving of the file. + await self.frame_num.write(self.frame_num.value + 1) + + # await self.acquire.write(AcqStatuses.IDLE.value) + + return False + + @staticmethod + def saver(request_queue, response_queue): + """The saver callback for threading-based queueing.""" + while True: + received = request_queue.get() + filename = received["filename"] + data = received["data"] + frame_number = received["frame_number"] + try: + save_hdf5_nd(fname=filename, data=data, mode="x", group_path="enc1") + print( + f"{now()}: saved {frame_number=} {data.shape} data into:\n {filename}" + ) + + success = True + error_message = "" + except Exception as exc: # pylint: disable=broad-exception-caught + success = False + error_message = exc + print( + f"Cannot save file {filename!r} due to the following exception:\n{exc}" + ) + + response = {"success": success, "error_message": error_message} + response_queue.put(response) + + +class OphydDeviceWithCaprotoIOC(Device): + """An ophyd Device which works with the base caproto extension IOC.""" + + write_dir = Cpt(EpicsSignal, "write_dir", string=True) + file_name = Cpt(EpicsSignal, "file_name", string=True) + full_file_path = Cpt(EpicsSignalRO, "full_file_path", string=True) + frame_num = Cpt(EpicsSignal, "frame_num") + ioc_stage = Cpt(EpicsSignal, "stage", string=True) + acquire = Cpt(EpicsSignal, "acquire", string=True) + + def set(self, command): + """The set method with values for staging and acquiring.""" + + # print(f"{now()}: {command = }") + if command in [StageStates.STAGED.value, "stage"]: + expected_old_value = StageStates.UNSTAGED.value + expected_new_value = StageStates.STAGED.value + obj = self.ioc_stage + cmd = StageStates.STAGED.value + + if command in [StageStates.UNSTAGED.value, "unstage"]: + expected_old_value = StageStates.STAGED.value + expected_new_value = StageStates.UNSTAGED.value + obj = self.ioc_stage + cmd = StageStates.UNSTAGED.value + + if command in [AcqStatuses.ACQUIRING.value, "acquire"]: + expected_old_value = AcqStatuses.ACQUIRING.value + expected_new_value = AcqStatuses.IDLE.value + obj = self.acquire + cmd = AcqStatuses.ACQUIRING.value + + def cb(value, old_value, **kwargs): + # pylint: disable=unused-argument + # print(f"{now()}: {old_value} -> {value}") + if value == expected_new_value and old_value == expected_old_value: + return True + return False + + st = SubscriptionStatus(obj, callback=cb, run=False) + # print(f"{now()}: {cmd = }") + obj.put(cmd) + return st + + +def check_args(parser_, split_args_): + """Helper function to process caproto CLI args.""" + parsed_args = parser_.parse_args() + prefix = parsed_args.prefix + if not prefix: + parser_.error("The 'prefix' argument must be specified.") + + ioc_opts, run_opts = split_args_(parsed_args) + return ioc_opts, run_opts + + +if __name__ == "__main__": + parser, split_args = template_arg_parser( + default_prefix="", desc=textwrap.dedent(CaprotoSaveIOC.__doc__) + ) + ioc_options, run_options = check_args(parser, split_args) + ioc = CaprotoSaveIOC(**ioc_options) + run(ioc.pvdb, **run_options) diff --git a/src/srx_caproto_iocs/example/__init__.py b/src/srx_caproto_iocs/example/__init__.py new file mode 100644 index 0000000..5baeb5f --- /dev/null +++ b/src/srx_caproto_iocs/example/__init__.py @@ -0,0 +1 @@ +""""Example Caproto IOC code.""" diff --git a/src/srx_caproto_iocs/example/caproto_ioc.py b/src/srx_caproto_iocs/example/caproto_ioc.py new file mode 100644 index 0000000..d4b38a9 --- /dev/null +++ b/src/srx_caproto_iocs/example/caproto_ioc.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import textwrap + +from caproto import ChannelType +from caproto.server import PVGroup, pvproperty, run, template_arg_parser + +from ..base import check_args + + +class CaprotoStringIOC(PVGroup): + """Test channel types for strings.""" + + common_kwargs = {"max_length": 255, "string_encoding": "utf-8"} + + bare_string = pvproperty( + value="bare_string", doc="A test for a bare string", **common_kwargs + ) + implicit_string_type = pvproperty( + value="implicit_string_type", + doc="A test for an implicit string type", + report_as_string=True, + **common_kwargs, + ) + string_type = pvproperty( + value="string_type", + doc="A test for a string type", + dtype=str, + report_as_string=True, + **common_kwargs, + ) + string_type_enum = pvproperty( + value="string_type_enum", + doc="A test for a string type", + dtype=ChannelType.STRING, + **common_kwargs, + ) + char_type_as_string = pvproperty( + value="char_type_as_string", + doc="A test for a char type reported as string", + report_as_string=True, + dtype=ChannelType.CHAR, + **common_kwargs, + ) + char_type = pvproperty( + value="char_type", + doc="A test for a char type not reported as string", + dtype=ChannelType.CHAR, + **common_kwargs, + ) + + +if __name__ == "__main__": + parser, split_args = template_arg_parser( + default_prefix="", desc=textwrap.dedent(CaprotoStringIOC.__doc__) + ) + ioc_options, run_options = check_args(parser, split_args) + ioc = CaprotoStringIOC(**ioc_options) + + run(ioc.pvdb, **run_options) diff --git a/src/srx_caproto_iocs/example/ophyd.py b/src/srx_caproto_iocs/example/ophyd.py new file mode 100644 index 0000000..a44ac08 --- /dev/null +++ b/src/srx_caproto_iocs/example/ophyd.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from ophyd import Component as Cpt +from ophyd import Device, EpicsSignal + + +class OphydChannelTypes(Device): + """An ophyd Device which works with the CaprotoStringIOC caproto IOC.""" + + bare_string = Cpt(EpicsSignal, "bare_string", string=True) + implicit_string_type = Cpt(EpicsSignal, "implicit_string_type", string=True) + string_type = Cpt(EpicsSignal, "string_type", string=True) + string_type_enum = Cpt(EpicsSignal, "string_type_enum", string=True) + char_type_as_string = Cpt(EpicsSignal, "char_type_as_string", string=True) + char_type = Cpt(EpicsSignal, "char_type", string=True) diff --git a/src/srx_caproto_iocs/sis_scaler/__init__.py b/src/srx_caproto_iocs/sis_scaler/__init__.py new file mode 100644 index 0000000..479fe07 --- /dev/null +++ b/src/srx_caproto_iocs/sis_scaler/__init__.py @@ -0,0 +1 @@ +""""SIS scaler Caproto IOC code.""" diff --git a/src/srx_caproto_iocs/utils.py b/src/srx_caproto_iocs/utils.py new file mode 100644 index 0000000..737fe3d --- /dev/null +++ b/src/srx_caproto_iocs/utils.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import datetime +from pathlib import Path + +import h5py +import numpy as np + + +def now(as_object=False): + """A helper function to return ISO 8601 formatted datetime string.""" + _now = datetime.datetime.now() + if as_object: + return _now + return _now.isoformat() + + +def save_hdf5_zebra( + fname, + data, + dtype="float32", + mode="x", +): + """The function to export the 1-D data to an HDF5 file. + + Check https://docs.h5py.org/en/stable/high/file.html#opening-creating-files for modes: + + r Readonly, file must exist (default) + r+ Read/write, file must exist + w Create file, truncate if exists + w- or x Create file, fail if exists + a Read/write if exists, create otherwise + """ + with h5py.File(fname, mode, libver="latest") as h5file_desc: + for pvname, value in data.items(): + dataset = h5file_desc.create_dataset( + pvname, + data=value, + dtype=dtype, + ) + dataset.flush() + + +def save_hdf5_nd( + fname, + data, + group_name="/entry", + group_path="data/data", + dtype="float32", + mode="x", +): + """The function to export the N-D data to an HDF5 file (N>1). + + Check https://docs.h5py.org/en/stable/high/file.html#opening-creating-files for modes: + + r Readonly, file must exist (default) + r+ Read/write, file must exist + w Create file, truncate if exists + w- or x Create file, fail if exists + a Read/write if exists, create otherwise + """ + update_existing = Path(fname).is_file() + with h5py.File(fname, mode, libver="latest") as h5file_desc: + frame_shape = data.shape + if not update_existing: + group = h5file_desc.create_group(group_name) + dataset = group.create_dataset( + group_path, + data=np.full(fill_value=np.nan, shape=(1, *frame_shape)), + maxshape=(None, *frame_shape), + chunks=(1, *frame_shape), + dtype=dtype, + ) + frame_num = 0 + else: + dataset = h5file_desc[f"{group_name}/{group_path}"] + frame_num = dataset.shape[0] + + # https://docs.h5py.org/en/stable/swmr.html + h5file_desc.swmr_mode = True + + dataset.resize((frame_num + 1, *frame_shape)) + dataset[frame_num, ...] = data + dataset.flush() diff --git a/src/srx_caproto_iocs/zebra/__init__.py b/src/srx_caproto_iocs/zebra/__init__.py new file mode 100644 index 0000000..19c633b --- /dev/null +++ b/src/srx_caproto_iocs/zebra/__init__.py @@ -0,0 +1 @@ +""""Zebra Caproto IOC code.""" diff --git a/src/srx_caproto_iocs/zebra/caproto_ioc.py b/src/srx_caproto_iocs/zebra/caproto_ioc.py new file mode 100644 index 0000000..1f39217 --- /dev/null +++ b/src/srx_caproto_iocs/zebra/caproto_ioc.py @@ -0,0 +1,227 @@ +from __future__ import annotations + +import textwrap +from enum import Enum + +from caproto import ChannelType +from caproto.server import pvproperty, run, template_arg_parser + +from ..base import CaprotoSaveIOC, check_args +from ..utils import now, save_hdf5_zebra + +# def export_nano_zebra_data(zebra, filepath, fastaxis): +# j = 0 +# while zebra.pc.data_in_progress.get() == 1: +# print("Waiting for zebra...") +# ttime.sleep(0.1) +# j += 1 +# if j > 10: +# print("THE ZEBRA IS BEHAVING BADLY CARRYING ON") +# break + +# time_d = zebra.pc.data.time.get() +# enc1_d = zebra.pc.data.enc1.get() +# enc2_d = zebra.pc.data.enc2.get() +# enc3_d = zebra.pc.data.enc3.get() + +# px = zebra.pc.pulse_step.get() +# if fastaxis == 'NANOHOR': +# # Add half pixelsize to correct encoder +# enc1_d = enc1_d + (px / 2) +# elif fastaxis == 'NANOVER': +# # Add half pixelsize to correct encoder +# enc2_d = enc2_d + (px / 2) +# elif fastaxis == 'NANOZ': +# # Add half pixelsize to correct encoder +# enc3_d = enc3_d + (px / 2) + +# size = (len(time_d),) +# with h5py.File(filepath, "w") as f: +# dset0 = f.create_dataset("zebra_time", size, dtype="f") +# dset0[...] = np.array(time_d) +# dset1 = f.create_dataset("enc1", size, dtype="f") +# dset1[...] = np.array(enc1_d) +# dset2 = f.create_dataset("enc2", size, dtype="f") +# dset2[...] = np.array(enc2_d) +# dset3 = f.create_dataset("enc3", size, dtype="f") +# dset3[...] = np.array(enc3_d) + + +# class ZebraPositionCaptureData(Device): +# """ +# Data arrays for the Zebra position capture function and their metadata. +# """ +# # Data arrays +# ... +# enc1 = Cpt(EpicsSignal, "PC_ENC1") # XF:05IDD-ES:1{Dev:Zebra2}:PC_ENC1 +# enc2 = Cpt(EpicsSignal, "PC_ENC2") # XF:05IDD-ES:1{Dev:Zebra2}:PC_ENC2 +# enc3 = Cpt(EpicsSignal, "PC_ENC3") # XF:05IDD-ES:1{Dev:Zebra2}:PC_ENC3 +# time = Cpt(EpicsSignal, "PC_TIME") # XF:05IDD-ES:1{Dev:Zebra2}:PC_TIME +# ... + +# class ZebraPositionCapture(Device): +# """ +# Signals for the position capture function of the Zebra +# """ + +# # Configuration settings and status PVs +# ... +# pulse_step = Cpt(EpicsSignalWithRBV, "PC_PULSE_STEP") # XF:05IDD-ES:1{Dev:Zebra2}:PC_PULSE_STEP +# ... +# data_in_progress = Cpt(EpicsSignalRO, "ARRAY_ACQ") # XF:05IDD-ES:1{Dev:Zebra2}:ARRAY_ACQ +# ... +# data = Cpt(ZebraPositionCaptureData, "") + +# nanoZebra = SRXZebra( +# "XF:05IDD-ES:1{Dev:Zebra2}:", name="nanoZebra", +# read_attrs=["pc.data.enc1", "pc.data.enc2", "pc.data.enc3", "pc.data.time"], +# ) + +DEFAULT_MAX_LENGTH = 100_000 + + +class DevTypes(Enum): + """Enum class for devices.""" + + ZEBRA = "zebra" + SCALER = "scaler" + + +class ZebraSaveIOC(CaprotoSaveIOC): + """Zebra caproto save IOC.""" + + dev_type = pvproperty( + value=DevTypes.ZEBRA.value, + enum_strings=[x.value for x in DevTypes], + dtype=ChannelType.ENUM, + doc="Pick device type", + ) + + enc1 = pvproperty( + value=0, + doc="enc1 data", + max_length=DEFAULT_MAX_LENGTH, + ) + + enc2 = pvproperty( + value=0, + doc="enc2 data", + max_length=DEFAULT_MAX_LENGTH, + ) + + enc3 = pvproperty( + value=0, + doc="enc3 data", + max_length=DEFAULT_MAX_LENGTH, + ) + + zebra_time = pvproperty( + value=0, + doc="zebra time", + max_length=DEFAULT_MAX_LENGTH, + ) + + i0 = pvproperty( + value=0, + doc="i0 data", + max_length=DEFAULT_MAX_LENGTH, + ) + + im = pvproperty( + value=0, + doc="im data", + max_length=DEFAULT_MAX_LENGTH, + ) + + it = pvproperty( + value=0, + doc="it data", + max_length=DEFAULT_MAX_LENGTH, + ) + + sis_time = pvproperty( + value=0, + doc="sis time", + max_length=DEFAULT_MAX_LENGTH, + ) + + # def __init__(self, *args, external_pvs=None, **kwargs): + # """Init method. + + # external_pvs : dict + # a dictionary of external PVs with keys as human-readable names. + # """ + # super().__init__(*args, **kwargs) + # self._external_pvs = external_pvs + + async def _get_current_dataset( + self, *args, **kwargs + ): # , frame, external_pv="enc1"): + # client_context = Context() + # (pvobject,) = await client_context.get_pvs(self._external_pvs[external_pv]) + # print(f"{pvobject = }") + # # pvobject = pvobjects[0] + # ret = await pvobject.read() + + if self.dev_type == DevTypes.ZEBRA: + pvnames = ["enc1", "enc2", "enc3", "zebra_time"] + else: + pvnames = ["i0", "im", "it", "sis_time"] + + dataset = {} + for pvname in pvnames: + dataset[pvname] = getattr(self, pvname).value + + print(f"{now()}:\n{dataset}") + + return dataset + + @staticmethod + def saver(request_queue, response_queue): + """The saver callback for threading-based queueing.""" + while True: + received = request_queue.get() + filename = received["filename"] + data = received["data"] + # 'frame_number' is not used for this exporter. + frame_number = received["frame_number"] # noqa: F841 + try: + save_hdf5_zebra(fname=filename, data=data, mode="x") + print(f"{now()}: saved data into:\n {filename}") + + success = True + error_message = "" + except Exception as exc: # pylint: disable=broad-exception-caught + success = False + error_message = exc + print( + f"Cannot save file {filename!r} due to the following exception:\n{exc}" + ) + + response = {"success": success, "error_message": error_message} + response_queue.put(response) + + +if __name__ == "__main__": + parser, split_args = template_arg_parser( + default_prefix="", desc=textwrap.dedent(ZebraSaveIOC.__doc__) + ) + ioc_options, run_options = check_args(parser, split_args) + + # external_pv_prefix = ( + # ioc_options["prefix"].replace("{{", "{").replace("}}", "}") + # ) # "XF:05IDD-ES:1{Dev:Zebra2}:" + + # external_pvs = { + # "pulse_step": external_pv_prefix + "PC_PULSE_STEP", + # "data_in_progress": external_pv_prefix + "ARRAY_ACQ", + # "enc1": external_pv_prefix + "PC_ENC1", + # "enc2": external_pv_prefix + "PC_ENC2", + # "enc3": external_pv_prefix + "PC_ENC3", + # "enc4": external_pv_prefix + "PC_ENC4", + # "time": external_pv_prefix + "PC_TIME", + # } + + # ioc = ZebraSaveIOC(external_pvs=external_pvs, **ioc_options) + ioc = ZebraSaveIOC(**ioc_options) + run(ioc.pvdb, **run_options) diff --git a/src/srx_caproto_iocs/zebra/ophyd.py b/src/srx_caproto_iocs/zebra/ophyd.py new file mode 100644 index 0000000..2c17aa0 --- /dev/null +++ b/src/srx_caproto_iocs/zebra/ophyd.py @@ -0,0 +1,7 @@ +from __future__ import annotations + +from ..base import OphydDeviceWithCaprotoIOC + + +class ZebraWithCaprotoIOC(OphydDeviceWithCaprotoIOC): + """An ophyd Device which works with the Zebra caproto extension IOC.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..a605f72 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import os +import socket +import string +import subprocess +import sys +import time as ttime + +import pytest + +from srx_caproto_iocs.base import OphydDeviceWithCaprotoIOC +from srx_caproto_iocs.example.ophyd import OphydChannelTypes + +CAPROTO_PV_PREFIX = "BASE:{{Dev:Save1}}:" +OPHYD_PV_PREFIX = CAPROTO_PV_PREFIX.replace("{{", "{").replace("}}", "}") + + +def get_epics_env(): + first_three = ".".join(socket.gethostbyname(socket.gethostname()).split(".")[:3]) + broadcast = f"{first_three}.255" + + print(f"{broadcast = }") + + # from pprint import pformat + # import netifaces + # interfaces = netifaces.interfaces() + # print(f"{interfaces = }") + # for interface in interfaces: + # addrs = netifaces.ifaddresses(interface) + # try: + # print(f"{interface = }: {pformat(addrs[netifaces.AF_INET])}") + # except Exception as e: + # print(f"{interface = }: exception:\n {e}") + + return { + "EPICS_CAS_BEACON_ADDR_LIST": os.getenv("EPICS_CA_ADDR_LIST", broadcast), + "EPICS_CAS_AUTO_BEACON_ADDR_LIST": "no", + } + + +def start_ioc_subprocess(ioc_name="srx_caproto_iocs.base", pv_prefix=CAPROTO_PV_PREFIX): + env = get_epics_env() + + command = f"{sys.executable} -m {ioc_name} --prefix={pv_prefix} --list-pvs" + print( + f"\nStarting caproto IOC in via a fixture using the following command:\n\n {command}\n" + ) + os.environ.update(env) + return subprocess.Popen( + command.split(), + start_new_session=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + shell=False, + env=os.environ, + ) + + +@pytest.fixture(scope="session") +def base_caproto_ioc(wait=5): + p = start_ioc_subprocess( + ioc_name="srx_caproto_iocs.base", pv_prefix=CAPROTO_PV_PREFIX + ) + + print(f"Wait for {wait} seconds...") + ttime.sleep(wait) + + yield p + + p.terminate() + + std_out, std_err = p.communicate() + std_out = std_out.decode() + sep = "=" * 80 + print(f"STDOUT:\n{sep}\n{std_out}") + print(f"STDERR:\n{sep}\n{std_err}") + + +@pytest.fixture() +def base_ophyd_device(): + dev = OphydDeviceWithCaprotoIOC( + OPHYD_PV_PREFIX, name="ophyd_device_with_caproto_ioc" + ) + yield dev + dev.ioc_stage.put("unstaged") + + +@pytest.fixture(scope="session") +def caproto_ioc_channel_types(wait=5): + p = start_ioc_subprocess( + ioc_name="srx_caproto_iocs.example.caproto_ioc", pv_prefix=CAPROTO_PV_PREFIX + ) + + print(f"Wait for {wait} seconds...") + ttime.sleep(wait) + + yield p + + p.terminate() + + std_out, std_err = p.communicate() + std_out = std_out.decode() + sep = "=" * 80 + print(f"STDOUT:\n{sep}\n{std_out}") + print(f"STDERR:\n{sep}\n{std_err}") + + +@pytest.fixture() +def ophyd_channel_types(): + dev = OphydChannelTypes(OPHYD_PV_PREFIX, name="ophyd_channel_type") + letters = iter(string.ascii_letters) + for cpt in sorted(dev.component_names): + getattr(dev, cpt).put(next(letters)) + return dev diff --git a/tests/test_base_ophyd.py b/tests/test_base_ophyd.py new file mode 100644 index 0000000..7e514bd --- /dev/null +++ b/tests/test_base_ophyd.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import re +import shutil +import time as ttime +import uuid +from pathlib import Path + +import h5py +import pytest + +from srx_caproto_iocs.utils import now + + +@pytest.mark.cloud_friendly() +@pytest.mark.parametrize( + "date_template", ["%Y/%m/", "%Y/%m/%d", "mydir/%Y/%m/%d", "disguised_spaces_%c"] +) +def test_base_ophyd_templates( + base_caproto_ioc, base_ophyd_device, date_template, num_frames=50, remove=False +): + tmpdirname = f"/tmp/srx-caproto-iocs/{str(uuid.uuid4())[:2]}" + date = now(as_object=True) + write_dir_root = Path(tmpdirname) + dir_template = f"{write_dir_root}/{date_template}" + + # We pre-create the test directory in advance as the IOC is not supposed to create one. + # The assumption for the IOC is that the directory will exist before saving a file to that. + # We need to substitute the unsupported characters below for it to work, as the IOC will do + # the same in `full_file_path` before returning the value. + sanitizer = re.compile(pattern=r"[\":<>|\*\?\s]") + write_dir = Path(sanitizer.sub("_", date.strftime(dir_template))) + write_dir.mkdir(parents=True, exist_ok=True) + + file_template = "scan_{num:06d}_{uid}.hdf5" + + dev = base_ophyd_device + dev.write_dir.put(dir_template) + dev.file_name.put(file_template) + + dev.set("stage").wait() + + full_file_path = dev.full_file_path.get() + print(f"{full_file_path = }") + + for i in range(num_frames): + print(f"Collecting frame {i}...") + dev.set("acquire").wait() + ttime.sleep(0.1) + + dev.set("unstage").wait() + + assert full_file_path, "The returned 'full_file_path' did not change." + assert Path(full_file_path).is_file(), f"No such file '{full_file_path}'" + + with h5py.File(full_file_path, "r", swmr=True) as f: + dataset = f["/entry/data/data"] + assert dataset.shape == (num_frames, 256, 256) + + ttime.sleep(1.0) + + if remove: + shutil.rmtree(tmpdirname) diff --git a/tests/test_string_ioc.py b/tests/test_string_ioc.py new file mode 100644 index 0000000..af2e46c --- /dev/null +++ b/tests/test_string_ioc.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import re +import string +import subprocess + +import pytest + +LIMIT = 39 +STRING_39 = string.ascii_letters[:LIMIT] +STRING_LONGER = string.ascii_letters + + +@pytest.mark.cloud_friendly() +@pytest.mark.parametrize("value", [STRING_39, STRING_LONGER]) +def test_strings( + caproto_ioc_channel_types, + ophyd_channel_types, + value, +): + ophyd_channel_types.bare_string.put(value) + + if len(value) <= LIMIT: + ophyd_channel_types.implicit_string_type.put(value) + else: + with pytest.raises(ValueError, match="byte string too long"): + ophyd_channel_types.implicit_string_type.put(value) + + if len(value) <= LIMIT: + ophyd_channel_types.string_type.put(value) + else: + with pytest.raises(ValueError, match="byte string too long"): + ophyd_channel_types.string_type.put(value) + + if len(value) <= LIMIT: + ophyd_channel_types.string_type_enum.put(value) + else: + with pytest.raises(ValueError, match="byte string too long"): + ophyd_channel_types.string_type_enum.put(value) + + if len(value) <= LIMIT: + ophyd_channel_types.char_type_as_string.put(value) + else: + with pytest.raises(ValueError, match="byte string too long"): + ophyd_channel_types.char_type_as_string.put(value) + + ophyd_channel_types.char_type.put(value) + + +@pytest.mark.cloud_friendly() +@pytest.mark.needs_epics_core() +def test_cainfo(caproto_ioc_channel_types, ophyd_channel_types): + for cpt in sorted(ophyd_channel_types.component_names): + command = ["cainfo", getattr(ophyd_channel_types, cpt).pvname] + command_str = " ".join(command) + ret = subprocess.run( + command, + check=False, + capture_output=True, + ) + stdout = ret.stdout.decode() + print( + f"command: {command_str}\n {ret.returncode=}\n STDOUT:\n{ret.stdout.decode()}\n STDERR:\n{ret.stderr.decode()}\n" + ) + assert ret.returncode == 0 + if cpt in [ + "char_type_as_string", + "implicit_string_type", + "string_type", + "string_type_enum", + ]: + assert "Native data type: DBF_STRING" in stdout + else: + assert "Native data type: DBF_CHAR" in stdout + + +@pytest.mark.cloud_friendly() +@pytest.mark.needs_epics_core() +@pytest.mark.parametrize("value", [STRING_39, STRING_LONGER]) +def test_caput(caproto_ioc_channel_types, ophyd_channel_types, value): + option = "" + for cpt in sorted(ophyd_channel_types.component_names): + if cpt in [ + "char_type_as_string", + "implicit_string_type", + "string_type", + "string_type_enum", + ]: + option = "-s" + would_trim = True + else: + option = "-S" + would_trim = False + command = ["caput", option, getattr(ophyd_channel_types, cpt).pvname, value] + command_str = " ".join(command) + ret = subprocess.run( + command, + check=False, + capture_output=True, + ) + stdout = ret.stdout.decode() + print( + f"command: {command_str}\n {ret.returncode=}\n STDOUT:\n{stdout}\n STDERR:\n{ret.stderr.decode()}\n" + ) + assert ret.returncode == 0 + actual = re.search("New : (.*)", stdout).group(1).split()[-1].rstrip() + if not would_trim or len(value) == LIMIT: + assert actual == value + else: + assert len(actual) < len(value)