diff --git a/socs/agents/lakeshore240/agent.py b/socs/agents/lakeshore240/agent.py index 1ae536fee..e5462a3c6 100644 --- a/socs/agents/lakeshore240/agent.py +++ b/socs/agents/lakeshore240/agent.py @@ -1,21 +1,18 @@ import argparse -import dataclasses import os import queue import time import traceback import warnings -from dataclasses import dataclass, fields -from typing import ( - Any, Dict, Generator, Optional, Tuple, Type, Callable, get_args, get_origin, - Union -) +from dataclasses import dataclass +from typing import ( Any, Dict, Optional, ) import txaio # type: ignore from ocs import ocs_agent, site_config from ocs.ocs_twisted import Pacemaker from socs.Lakeshore.Lakeshore240 import Module +from socs.util import BaseAction, register_task_from_action, OcsOpReturnType txaio.use_twisted() @@ -25,215 +22,91 @@ log = txaio.make_logger() # pylint: disable=E1101 -ActionReturnType = Optional[Dict[str, Any]] -OcsOpReturnType = Tuple[bool, str] -OcsInlineCallbackReturnType = Generator[Any, Any, OcsOpReturnType] +class LS240Action(BaseAction): + def process(self, module: Module) -> None: + raise NotImplementedError +@dataclass +class UploadCalCurve(LS240Action): + """upload_cal_curve(channel, filename) -class Actions: - "Namespace to hold action classes for the Lakeshore240 agent." + **Task** - Upload a calibration curve to a channel. - @dataclass - class BaseAction: - "Base class for all actions." - - def __post_init__(self) -> None: - self.processed: bool = False - self.success: bool = False - self.traceback: Optional[str] = None - self.result: ActionReturnType = None - - def resolve_action( - self, - success: bool, - traceback: Optional[str] = None, - result: ActionReturnType = None - ) -> None: - self.success = success - self.traceback = traceback - self.result = result - self.processed = True - - def process(self, module: Module) -> ActionReturnType: - raise NotImplementedError - - def sleep_until_processed(self, interval=0.2) -> None: - while not self.processed: - time.sleep(interval) - - @dataclass - class UploadCalCurve(BaseAction): - """upload_cal_curve(channel, filename) - - **Task** - Upload a calibration curve to a channel. - - Args - ------ - channel (int): - Channel number, 1-8. - filename (str): - Filename for calibration curve. - """ - - channel: int - filename: str - - def process(self, module: Module) -> ActionReturnType: - log.info(f"Starting upload to channel {self.channel}...") - channel = module.channels[self.channel - 1] - channel.load_curve(self.filename) - time.sleep(0.1) - return None - - @dataclass - class SetValues(BaseAction): - """set_values(channel, sensor=None, auto_range=None, range=None,\ - current_reversal=None, units=None, enabled=None, name=None) - - **Task** - Set sensor parameters for a Lakeshore240 Channel. - - Args - --------- - channel (int): - Channel number to set. Valid choices are 1-8. - sensor (int, optional): - Specifies sensor type. See - :func:`socs.Lakeshore.Lakeshore240.Channel.set_values` for - possible types. - auto_range (int, optional): - Specifies if channel should use autorange. Must be 0 or 1. - range (int, optional): - Specifies range if auto_range is false. Only settable for NTC - RTD. See - :func:`socs.Lakeshore.Lakeshore240.Channel.set_values` for - possible ranges. - current_reversal (int, optional): - Specifies if input current reversal is on or off. - Always 0 if input is a diode. - units (int, optional): - Specifies preferred units parameter, and sets the units for - alarm settings. See - :func:`socs.Lakeshore.Lakeshore240.Channel.set_values` for - possible units. - enabled (int, optional): - Sets if channel is enabled. - name (str, optional): - Sets name of channel. - """ - - channel: int - sensor: Optional[int] = None - auto_range: Optional[int] = None - range: Optional[int] = None - current_reversal: Optional[int] = None - units: Optional[int] = None - enabled: Optional[int] = None - name: Optional[str] = None - - def process(self, module: Module) -> ActionReturnType: - log.info(f"Setting values for channel {self.channel}...") - module.channels[self.channel - 1].set_values( - sensor=self.sensor, - auto_range=self.auto_range, - range=self.range, - current_reversal=self.current_reversal, - unit=self.units, - enabled=self.enabled, - name=self.name, - ) - time.sleep(0.1) - return None - - -def is_instanceable(t: Type) -> bool: - """ - Checks if its possible to run isinstance with a specified type. This is - needed because older version of python don't let you run this on subscripted - generics. + Args + ------ + channel (int): + Channel number, 1-8. + filename (str): + Filename for calibration curve. """ - try: - isinstance(0, t) - return True - except Exception: - return False -def get_param_type(t: Type) -> Optional[Type]: - """ - Takes in a dataclass field type and returns a type that is accepted - by the OCS param decorator. This will return the original type if it - works with isinstance, or will attempt to unwrap an optional type. Other - types are not currently supported. If it fails, it will return None. - """ - origin_type = get_origin(t) - - # Unwrap possible option type - if origin_type == Union: - sub_types = get_args(t) - if len(sub_types) != 2: - return None - if type(None) not in sub_types: - return None - for st in sub_types: - if st is not type(None): - if is_instanceable(st): - return st - - elif is_instanceable(t): - return t - - return None - - -def register_task_from_action( - agent: ocs_agent.OCSAgent, - name: str, - action_class: Type[Actions.BaseAction], - queue: "queue.Queue[Actions.BaseAction]" -) -> None: - """ - Registers an OCSTask from an Action type. This will define ocs_params based - on the dataclass fields, and set the task docstrings equal to the action - class docstrings. + channel: int + filename: str + + def process(self, module: Module) -> None: + log.info(f"Starting upload to channel {self.channel}...") + channel = module.channels[self.channel - 1] + channel.load_curve(self.filename) + time.sleep(0.1) + raise Exception("TEST") + +@dataclass +class SetValues(LS240Action): + """set_values(channel, sensor=None, auto_range=None, range=None,\ + current_reversal=None, units=None, enabled=None, name=None) + + **Task** - Set sensor parameters for a Lakeshore240 Channel. + + Args + --------- + channel (int): + Channel number to set. Valid choices are 1-8. + sensor (int, optional): + Specifies sensor type. See + :func:`socs.Lakeshore.Lakeshore240.Channel.set_values` for + possible types. + auto_range (int, optional): + Specifies if channel should use autorange. Must be 0 or 1. + range (int, optional): + Specifies range if auto_range is false. Only settable for NTC + RTD. See + :func:`socs.Lakeshore.Lakeshore240.Channel.set_values` for + possible ranges. + current_reversal (int, optional): + Specifies if input current reversal is on or off. + Always 0 if input is a diode. + units (int, optional): + Specifies preferred units parameter, and sets the units for + alarm settings. See + :func:`socs.Lakeshore.Lakeshore240.Channel.set_values` for + possible units. + enabled (int, optional): + Sets if channel is enabled. + name (str, optional): + Sets name of channel. """ - def task( - session: ocs_agent.OpSession, - params: Optional[Dict[str, Any]] = None - ) -> OcsOpReturnType: - _params = {} if params is None else params - action = action_class(**_params) - queue.put(action) - action.sleep_until_processed() - - if not action.success: - log.error("{name} failed to process...", name=name) - if action.traceback is not None: - log.error("traceback:\n{traceback}", traceback=action.traceback) - return False, f"{name} failed" - - if action.result is not None: - session.data.update(action.result) - - return True, f"{name} succeded" - - task.__doc__ = action_class.__doc__ - - # Adds ocs parameters - for f in fields(action_class): - param_type = get_param_type(f.type) - if param_type is None: - raise ValueError(f"Unsupported param type for arg {f.name}: {f.type}") - param_kwargs: Dict[str, Any] = { - 'type': param_type, - } - if f.default != dataclasses.MISSING: - param_kwargs['default'] = f.default - if isinstance(f.metadata, dict): - param_kwargs.update(f.metadata.get('ocs_param_kwargs', {})) - task = ocs_agent.param(f.name, **param_kwargs)(task) - - agent.register_task(name, task) - + channel: int + sensor: Optional[int] = None + auto_range: Optional[int] = None + range: Optional[int] = None + current_reversal: Optional[int] = None + units: Optional[int] = None + enabled: Optional[int] = None + name: Optional[str] = None + + def process(self, module: Module) -> None: + log.info(f"Setting values for channel {self.channel}...") + module.channels[self.channel - 1].set_values( + sensor=self.sensor, + auto_range=self.auto_range, + range=self.range, + current_reversal=self.current_reversal, + unit=self.units, + enabled=self.enabled, + name=self.name, + ) + time.sleep(0.1) class LS240_Agent: def __init__( @@ -245,14 +118,17 @@ def __init__( self.agent: ocs_agent.OCSAgent = agent self.port = port self.f_sample = f_sample - self.action_queue: "queue.Queue[Actions.BaseAction]" = queue.Queue() + self.action_queue: "queue.Queue[LS240Action]" = queue.Queue() + + def queue_action(action: LS240Action): + self.action_queue.put(action) # Register Operaionts register_task_from_action( - agent, 'set_values', Actions.SetValues, self.action_queue + agent, "set_values", SetValues, queue_action ) register_task_from_action( - agent, 'upload_cal_curve', Actions.UploadCalCurve, self.action_queue + agent, "upload_cal_curve", UploadCalCurve, queue_action ) agent.register_process("main", self.main, self._stop_main, startup=True) @@ -307,11 +183,15 @@ def _process_actions(self, module: Module) -> None: action = self.action_queue.get() try: log.info(f"Running action {action}") - result = action.process(module) - action.resolve_action(True, result=result) + action.process(module) + action.resolve_action(True) except Exception: # pylint: disable=broad-except log.error(f"Error processing action: {action}") - action.resolve_action(False, traceback=traceback.format_exc()) + action.resolve_action( + False, + traceback=traceback.format_exc(), + return_message="Uncaught Exception" + ) def main( self, @@ -328,7 +208,7 @@ def main( # Clear pre-existing actions while not self.action_queue.empty(): action = self.action_queue.get() - action.resolve_action(False, traceback="Aborted by main process") + action.resolve_action(False, return_message="Aborted by main process") exceptions_to_attempt_reconnect = (ConnectionError, TimeoutError) diff --git a/socs/util.py b/socs/util.py index 6d9ceaefd..0f57f8e7e 100644 --- a/socs/util.py +++ b/socs/util.py @@ -1,9 +1,205 @@ import hashlib +import dataclasses +from typing import ( + Optional, + Dict, + Any, + Type, + Callable, + Tuple, + get_origin, + get_args, + Union, + TypeVar, +) +import time + +from ocs import ocs_agent def get_md5sum(filename): m = hashlib.md5() - for line in open(filename, 'rb'): + for line in open(filename, "rb"): m.update(line) return m.hexdigest() + + +ActionResultType = Optional[Dict[str, Any]] +OcsOpReturnType = Tuple[bool, str] + + +@dataclasses.dataclass +class BaseAction: + """ + Base subclass for actions that correspond to OCS tasks. Such actions can + be used to generate a generic task that creates the action, passes + to some callback function, and waits until the action is resolved before + returning. + """ + + def __post_init__(self) -> None: + self._session_data: Dict[str, Any] = {} + self._processed: bool = False + self._success: bool = False + self._traceback: Optional[str] = None + self._return_message: Optional[str] = None + + def resolve_action( + self, + success: bool, + traceback: Optional[str] = None, + return_message: Optional[str] = None, + ) -> None: + """ + Resolves an action, signifying it has been completed. Tasks waiting for + this action to be resolved can then return. + """ + self._success = success + if traceback is not None: + self._traceback = traceback + if return_message is not None: + self._return_message = return_message + self._processed = True + + def update_session_data(self, data: Dict[str, Any]): + self._session_data.update(data) + + def sleep_until_resolved(self, session: ocs_agent.OpSession, interval=0.2) -> None: + """ + Sleeps until the action has been resolved. + """ + while not self._processed: + session.data = self._session_data + time.sleep(interval) + + session.data = self._session_data + + +def is_instanceable(t: Type) -> bool: + """ + Checks if its possible to run isinstance with a specified type. This is + needed because older version of python don't let you run this on subscripted + generics. + """ + try: + isinstance(0, t) + return True + except Exception: + return False + + +def get_param_type(t: Type) -> Optional[Type]: + """ + OCS param type variables require you to be able to run isinstance, + which does not work for subscripted generics like Optional in python3.8. + This function attempts to convert types to values that will be accepted + by ocs_agent.param, unwrapping optional types if we are unable to run + isinstance off the bat. + + Other subscripted generics such as List[...] or Dict[...] are not currently + supported. + + This function will return the unwrapped type, or None if it fails. + """ + origin_type = get_origin(t) + + # Unwrap possible option type + if origin_type == Union: + sub_types = get_args(t) + if len(sub_types) != 2: + return None + if type(None) not in sub_types: + return None + for st in sub_types: + if st is not type(None): + if is_instanceable(st): + return st + + elif is_instanceable(t): + # If this works, then it should work with ocs_agent.param + return t + + return None + + +BaseActionT = TypeVar("BaseActionT", bound=BaseAction) + + +def register_task_from_action( + agent: ocs_agent.OCSAgent, + name: str, + action_class: Type[BaseActionT], + callback: Callable[[BaseActionT], Any], +) -> None: + """ + Registers a generic OCS task based on an Action dataclass. This will + automatically set OCS parameters, and the docstrings based on the dataclass + fields and the Action class docstrings. + + The generic task will always do the following: + - generate an instance of the action class based on passed in params + - pass the action to the supplied callback function + - wait until the action is resolved by the agent, regularly updating + session.data. + - If the action is resolved successfully, will return a successful result. + If the action failed, this will log the action traceback if it is set, + and return a failed result. + + Args + -------- + agent: OCSAgent + OCS agent to use to register the task + name: str + Name of the task. This will be used to set the operation endpoint. + action_class: Type[BaseActionT] + The class to be used to generate the task. + callback: Callable[[BaseActionT], Any] + Function to call with the action instance after it is created. + It is expected that after calling this, the task will eventually + be processed and resolved by the agent. + """ + + def task( + session: ocs_agent.OpSession, params: Optional[Dict[str, Any]] = None + ) -> OcsOpReturnType: + _params: Dict[str, Any] = {} if params is None else params + action: BaseActionT = action_class(**_params) + callback(action) + action.sleep_until_resolved(session) + + if action._success: + if action._return_message is None: + return_message = f"{name} successful" + else: + return_message = action._return_message + return True, return_message + + else: + if action._return_message is None: + return_message = f"{name} failed" + else: + return_message = action._return_message + + if action._traceback is not None: + agent.log.error("traceback:\n{traceback}", traceback=action._traceback) + + return False, return_message + + task.__doc__ = action_class.__doc__ + + # Adds ocs parameters + for f in dataclasses.fields(action_class): + param_type = get_param_type(f.type) + if param_type is None: + raise ValueError(f"Unsupported param type for arg {f.name}: {f.type}") + param_kwargs: Dict[str, Any] = { + "type": param_type, + } + if f.default != dataclasses.MISSING: + param_kwargs["default"] = f.default + if isinstance(f.metadata, dict): + param_kwargs.update(f.metadata.get("ocs_param_kwargs", {})) + task = ocs_agent.param(f.name, **param_kwargs)(task) + + agent.register_task(name, task) diff --git a/tests/integration/test_ls240_agent_integration.py b/tests/integration/test_ls240_agent_integration.py index 15bcd4ef8..b37e53698 100644 --- a/tests/integration/test_ls240_agent_integration.py +++ b/tests/integration/test_ls240_agent_integration.py @@ -7,7 +7,7 @@ from ocs.base import OpCode from ocs.testing import create_agent_runner_fixture, create_client_fixture -from socs.agents.lakeshore240.agent import Actions as LS240Actions +from socs.agents.lakeshore240.agent import SetValues, UploadCalCurve from socs.testing.device_emulator import DeviceEmulator, create_device_emulator wait_for_crossbar = create_crossbar_fixture() @@ -70,7 +70,7 @@ def test_ls240_set_values(wait_for_crossbar, emulator: DeviceEmulator, run_agent 'INTYPE 1,1,1,0,0,1,1': ''} emulator.update_responses(responses) - set_values_params = LS240Actions.SetValues( + set_values_params = SetValues( channel=1, sensor=1, auto_range=1, range=0, current_reversal=0, units=1, enabled=1, name="Channel 01" ) @@ -105,7 +105,7 @@ def test_ls240_upload_cal_curve(wait_for_crossbar, emulator, run_agent, client, # No queries are sent during upload, so rely on the default response of '' responses = {'CRVHDR 1,DT-670-SD-1.4L,D60STND,2,325.0,1': ''} emulator.update_responses(responses) - upload_cal_curve_params = LS240Actions.UploadCalCurve( + upload_cal_curve_params = UploadCalCurve( channel=1, filename=str(cal_file) ) resp = client.upload_cal_curve(**asdict(upload_cal_curve_params))