diff --git a/magicbot/magic_tunable.py b/magicbot/magic_tunable.py index 78e8241..d8e3570 100644 --- a/magicbot/magic_tunable.py +++ b/magicbot/magic_tunable.py @@ -1,13 +1,23 @@ +import collections.abc import functools import inspect +import typing import warnings -from typing import Callable, Generic, Optional, TypeVar, overload +from typing import Callable, Generic, Optional, Sequence, TypeVar, Union, overload -from ntcore import NetworkTableInstance, Value +import ntcore +from ntcore import NetworkTableInstance from ntcore.types import ValueT + +class StructSerializable(typing.Protocol): + """Any type that is a wpiutil.wpistruct.""" + + WPIStruct: typing.ClassVar + + T = TypeVar("T") -V = TypeVar("V", bound=ValueT) +V = TypeVar("V", bound=Union[ValueT, StructSerializable, Sequence[StructSerializable]]) class tunable(Generic[V]): @@ -50,6 +60,10 @@ def execute(self): you will want to use setup_tunables to set the object up. In normal usage, MagicRobot does this for you, so you don't have to do anything special. + + .. versionchanged:: 2024.1.0 + Added support for WPILib Struct serializable types. + Integer defaults now create integer topics instead of double topics. """ # the way this works is we use a special class to indicate that it @@ -66,7 +80,7 @@ def execute(self): "_ntsubtable", "_ntwritedefault", # "__doc__", - "_mkv", + "_topic_type", "_nt", ) @@ -84,10 +98,15 @@ def __init__( self._ntdefault = default self._ntsubtable = subtable self._ntwritedefault = writeDefault - d = Value.makeValue(default) - self._mkv = Value.getFactoryByType(d.type()) # self.__doc__ = doc + self._topic_type = _get_topic_type_for_value(self._ntdefault) + if self._topic_type is None: + checked_type: type = type(self._ntdefault) + raise TypeError( + f"tunable is not publishable to NetworkTables, type: {checked_type.__name__}" + ) + @overload def __get__(self, instance: None, owner=None) -> "tunable[V]": ... @@ -96,11 +115,23 @@ def __get__(self, instance, owner=None) -> V: ... def __get__(self, instance, owner=None): if instance is not None: - return instance._tunables[self].value + return instance._tunables[self].get() return self def __set__(self, instance, value: V) -> None: - instance._tunables[self].setValue(self._mkv(value)) + instance._tunables[self].set(value) + + +def _get_topic_type_for_value(value) -> Optional[Callable[[ntcore.Topic], typing.Any]]: + topic_type = _get_topic_type(type(value)) + # bytes and str are Sequences. They must be checked before Sequence. + if topic_type is None and isinstance(value, collections.abc.Sequence): + if not value: + raise ValueError( + f"tunable default cannot be an empty sequence, got {value}" + ) + topic_type = _get_topic_type(Sequence[type(value[0])]) # type: ignore [misc] + return topic_type def setup_tunables(component, cname: str, prefix: Optional[str] = "components") -> None: @@ -124,7 +155,7 @@ def setup_tunables(component, cname: str, prefix: Optional[str] = "components") NetworkTables = NetworkTableInstance.getDefault() - tunables = {} + tunables: dict[tunable, ntcore.Topic] = {} for n in dir(cls): if n.startswith("_"): @@ -139,11 +170,12 @@ def setup_tunables(component, cname: str, prefix: Optional[str] = "components") else: key = "%s/%s" % (prefix, n) - ntvalue = NetworkTables.getEntry(key) + topic = prop._topic_type(NetworkTables.getTopic(key)) + ntvalue = topic.getEntry(prop._ntdefault) if prop._ntwritedefault: - ntvalue.setValue(prop._ntdefault) + ntvalue.set(prop._ntdefault) else: - ntvalue.setDefaultValue(prop._ntdefault) + ntvalue.setDefault(prop._ntdefault) tunables[prop] = ntvalue component._tunables = tunables @@ -201,6 +233,10 @@ class MyRobot(magicbot.MagicRobot): especially if you wish to monitor WPILib objects. .. versionadded:: 2018.1.0 + + .. versionchanged:: 2024.1.0 + WPILib Struct serializable types are supported when the return type is type hinted. + An ``int`` return type hint now creates an integer topic. """ if f is None: return functools.partial(feedback, key=key) @@ -222,10 +258,50 @@ class MyRobot(magicbot.MagicRobot): return f +_topic_types = { + bool: ntcore.BooleanTopic, + int: ntcore.IntegerTopic, + float: ntcore.DoubleTopic, + str: ntcore.StringTopic, + bytes: ntcore.RawTopic, +} +_array_topic_types = { + bool: ntcore.BooleanArrayTopic, + int: ntcore.IntegerArrayTopic, + float: ntcore.DoubleArrayTopic, + str: ntcore.StringArrayTopic, +} + + +def _get_topic_type( + return_annotation, +) -> Optional[Callable[[ntcore.Topic], typing.Any]]: + if return_annotation in _topic_types: + return _topic_types[return_annotation] + if hasattr(return_annotation, "WPIStruct"): + return lambda topic: ntcore.StructTopic(topic, return_annotation) + + # Check for PEP 484 generic types + origin = getattr(return_annotation, "__origin__", None) + args = typing.get_args(return_annotation) + if origin in (list, tuple, collections.abc.Sequence) and args: + # Ensure tuples are tuple[T, ...] or homogenous + if origin is tuple and not ( + (len(args) == 2 and args[1] is Ellipsis) or len(set(args)) == 1 + ): + return None + + inner_type = args[0] + if inner_type in _array_topic_types: + return _array_topic_types[inner_type] + if hasattr(inner_type, "WPIStruct"): + return lambda topic: ntcore.StructArrayTopic(topic, inner_type) + + def collect_feedbacks(component, cname: str, prefix: Optional[str] = "components"): """ Finds all methods decorated with :func:`feedback` on an object - and returns a list of 2-tuples (method, NetworkTables entry). + and returns a list of 2-tuples (method, NetworkTables entry setter). .. note:: This isn't useful for normal use. """ @@ -246,7 +322,19 @@ def collect_feedbacks(component, cname: str, prefix: Optional[str] = "components else: key = name - entry = nt.getEntry(key) - feedbacks.append((method, entry)) + return_annotation = typing.get_type_hints(method).get("return", None) + if return_annotation is not None: + topic_type = _get_topic_type(return_annotation) + else: + topic_type = None + + if topic_type is None: + entry = nt.getEntry(key) + setter = entry.setValue + else: + publisher = topic_type(nt.getTopic(key)).publish() + setter = publisher.set + + feedbacks.append((method, setter)) return feedbacks diff --git a/magicbot/magicrobot.py b/magicbot/magicrobot.py index 2f3249e..a109f24 100644 --- a/magicbot/magicrobot.py +++ b/magicbot/magicrobot.py @@ -10,7 +10,7 @@ import hal import wpilib -from ntcore import NetworkTableInstance, NetworkTableEntry +from ntcore import NetworkTableInstance # from wpilib.shuffleboard import Shuffleboard @@ -73,7 +73,7 @@ def __init__(self) -> None: self.__last_error_report = -10 self._components: List[Tuple[str, Any]] = [] - self._feedbacks: List[Tuple[Callable[[], Any], NetworkTableEntry]] = [] + self._feedbacks: List[Tuple[Callable[[], Any], Callable[[Any], Any]]] = [] self._reset_components: List[Tuple[Dict[str, Any], Any]] = [] self.__done = False @@ -720,13 +720,13 @@ def _do_periodics(self) -> None: """Run periodic methods which run in every mode.""" watchdog = self.watchdog - for method, entry in self._feedbacks: + for method, setter in self._feedbacks: try: value = method() except: self.onException() else: - entry.setValue(value) + setter(value) watchdog.addEpoch("@magicbot.feedback") diff --git a/setup.cfg b/setup.cfg index bd7606e..301fb05 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,7 +24,7 @@ zip_safe = False include_package_data = True packages = find: install_requires = - wpilib>=2024.1.1.0,<2025 + wpilib>=2024.3.2.1,<2025 setup_requires = setuptools_scm > 6 python_requires = >=3.8 diff --git a/tests/test_magicbot_feedback.py b/tests/test_magicbot_feedback.py new file mode 100644 index 0000000..3b17197 --- /dev/null +++ b/tests/test_magicbot_feedback.py @@ -0,0 +1,105 @@ +from typing import Sequence, Tuple + +import ntcore +from wpimath import geometry + +import magicbot + + +class BasicComponent: + @magicbot.feedback + def get_number(self): + return 0 + + @magicbot.feedback + def get_ints(self): + return (0,) + + @magicbot.feedback + def get_floats(self): + return (0.0, 0) + + def execute(self): + pass + + +class TypeHintedComponent: + @magicbot.feedback + def get_rotation(self) -> geometry.Rotation2d: + return geometry.Rotation2d() + + @magicbot.feedback + def get_rotation_array(self) -> Sequence[geometry.Rotation2d]: + return [geometry.Rotation2d()] + + @magicbot.feedback + def get_rotation_2_tuple(self) -> Tuple[geometry.Rotation2d, geometry.Rotation2d]: + return (geometry.Rotation2d(), geometry.Rotation2d()) + + @magicbot.feedback + def get_int(self) -> int: + return 0 + + @magicbot.feedback + def get_float(self) -> float: + return 0.5 + + @magicbot.feedback + def get_ints(self) -> Sequence[int]: + return (0,) + + @magicbot.feedback + def get_empty_strings(self) -> Sequence[str]: + return () + + def execute(self): + pass + + +class Robot(magicbot.MagicRobot): + basic: BasicComponent + type_hinted: TypeHintedComponent + + def createObjects(self): + pass + + +def test_feedbacks_with_type_hints(): + robot = Robot() + robot.robotInit() + nt = ntcore.NetworkTableInstance.getDefault().getTable("components") + + robot._do_periodics() + + for name, type_str, value in ( + ("basic/number", "double", 0.0), + ("basic/ints", "int[]", [0]), + ("basic/floats", "double[]", [0.0, 0.0]), + ("type_hinted/int", "int", 0), + ("type_hinted/float", "double", 0.5), + ("type_hinted/ints", "int[]", [0]), + ("type_hinted/empty_strings", "string[]", []), + ): + topic = nt.getTopic(name) + assert topic.getTypeString() == type_str + assert topic.genericSubscribe().get().value() == value + + for name, value in [ + ("type_hinted/rotation", geometry.Rotation2d()), + ]: + struct_type = type(value) + assert nt.getTopic(name).getTypeString() == f"struct:{struct_type.__name__}" + topic = nt.getStructTopic(name, struct_type) + assert topic.subscribe(None).get() == value + + for name, struct_type, value in ( + ("type_hinted/rotation_array", geometry.Rotation2d, [geometry.Rotation2d()]), + ( + "type_hinted/rotation_2_tuple", + geometry.Rotation2d, + [geometry.Rotation2d(), geometry.Rotation2d()], + ), + ): + assert nt.getTopic(name).getTypeString() == f"struct:{struct_type.__name__}[]" + topic = nt.getStructArrayTopic(name, struct_type) + assert topic.subscribe([]).get() == value diff --git a/tests/test_magicbot_tunable.py b/tests/test_magicbot_tunable.py new file mode 100644 index 0000000..fad999b --- /dev/null +++ b/tests/test_magicbot_tunable.py @@ -0,0 +1,56 @@ +import ntcore +import pytest +from wpimath import geometry + +from magicbot.magic_tunable import setup_tunables, tunable + + +def test_tunable() -> None: + class Component: + an_int = tunable(1) + ints = tunable([0]) + floats = tunable([1.0, 2.0]) + rotation = tunable(geometry.Rotation2d()) + rotations = tunable([geometry.Rotation2d()]) + + component = Component() + setup_tunables(component, "test_tunable") + nt = ntcore.NetworkTableInstance.getDefault().getTable("/components/test_tunable") + + for name, type_str, value in [ + ("an_int", "int", 1), + ("ints", "int[]", [0]), + ("floats", "double[]", [1.0, 2.0]), + ]: + topic = nt.getTopic(name) + assert topic.getTypeString() == type_str + assert topic.genericSubscribe().get().value() == value + + for name, value in [ + ("rotation", geometry.Rotation2d()), + ]: + struct_type = type(value) + assert nt.getTopic(name).getTypeString() == f"struct:{struct_type.__name__}" + topic = nt.getStructTopic(name, struct_type) + assert topic.subscribe(None).get() == value + + for name, struct_type, value in [ + ("rotations", geometry.Rotation2d, [geometry.Rotation2d()]), + ]: + assert nt.getTopic(name).getTypeString() == f"struct:{struct_type.__name__}[]" + topic = nt.getStructArrayTopic(name, struct_type) + assert topic.subscribe([]).get() == value + + +def test_tunable_errors(): + with pytest.raises(TypeError): + + class Component: + invalid = tunable(None) + + +def test_tunable_errors_with_empty_sequence(): + with pytest.raises(ValueError): + + class Component: + empty = tunable([])