Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

magicbot: Allow typed feedbacks using return type hints #208

Merged
merged 11 commits into from
Aug 7, 2024
118 changes: 103 additions & 15 deletions magicbot/magic_tunable.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
import collections.abc
import functools
import inspect
import typing
import warnings
from typing import Callable, Generic, Optional, TypeVar, overload
from typing import Callable, Generic, Optional, Sequence, TypeVar, Union, overload

from ntcore import NetworkTableInstance, Value
import ntcore
from ntcore import NetworkTableInstance
from ntcore.types import ValueT


class StructSerializable(typing.Protocol):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This protocol probably shouldn't belong here...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably should end up in wpimath, but I don't see why it can't be in two places.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I presume you mean wpiutil? I'll follow up with a PR over there then.

"""Any type that is a wpiutil.wpistruct."""

WPIStruct: typing.ClassVar


T = TypeVar("T")
V = TypeVar("V", bound=ValueT)
V = TypeVar("V", bound=Union[ValueT, StructSerializable, Sequence[StructSerializable]])


class tunable(Generic[V]):
Expand Down Expand Up @@ -50,6 +60,10 @@ def execute(self):
you will want to use setup_tunables to set the object up.
In normal usage, MagicRobot does this for you, so you don't
have to do anything special.

.. versionchanged:: 2024.1.0
Added support for WPILib Struct serializable types.
Integer defaults now create integer topics instead of double topics.
"""

# the way this works is we use a special class to indicate that it
Expand All @@ -66,7 +80,7 @@ def execute(self):
"_ntsubtable",
"_ntwritedefault",
# "__doc__",
"_mkv",
"_topic_type",
"_nt",
)

Expand All @@ -84,10 +98,15 @@ def __init__(
self._ntdefault = default
self._ntsubtable = subtable
self._ntwritedefault = writeDefault
d = Value.makeValue(default)
self._mkv = Value.getFactoryByType(d.type())
# self.__doc__ = doc

self._topic_type = _get_topic_type_for_value(self._ntdefault)
if self._topic_type is None:
checked_type: type = type(self._ntdefault)
raise TypeError(
f"tunable is not publishable to NetworkTables, type: {checked_type.__name__}"
)

@overload
def __get__(self, instance: None, owner=None) -> "tunable[V]": ...

Expand All @@ -96,11 +115,23 @@ def __get__(self, instance, owner=None) -> V: ...

def __get__(self, instance, owner=None):
if instance is not None:
return instance._tunables[self].value
return instance._tunables[self].get()
return self

def __set__(self, instance, value: V) -> None:
instance._tunables[self].setValue(self._mkv(value))
instance._tunables[self].set(value)


def _get_topic_type_for_value(value) -> Optional[Callable[[ntcore.Topic], typing.Any]]:
topic_type = _get_topic_type(type(value))
# bytes and str are Sequences. They must be checked before Sequence.
if topic_type is None and isinstance(value, collections.abc.Sequence):
if not value:
raise ValueError(
f"tunable default cannot be an empty sequence, got {value}"
)
topic_type = _get_topic_type(Sequence[type(value[0])]) # type: ignore [misc]
return topic_type


def setup_tunables(component, cname: str, prefix: Optional[str] = "components") -> None:
Expand All @@ -124,7 +155,7 @@ def setup_tunables(component, cname: str, prefix: Optional[str] = "components")

NetworkTables = NetworkTableInstance.getDefault()

tunables = {}
tunables: dict[tunable, ntcore.Topic] = {}

for n in dir(cls):
if n.startswith("_"):
Expand All @@ -139,11 +170,12 @@ def setup_tunables(component, cname: str, prefix: Optional[str] = "components")
else:
key = "%s/%s" % (prefix, n)

ntvalue = NetworkTables.getEntry(key)
topic = prop._topic_type(NetworkTables.getTopic(key))
ntvalue = topic.getEntry(prop._ntdefault)
if prop._ntwritedefault:
ntvalue.setValue(prop._ntdefault)
ntvalue.set(prop._ntdefault)
else:
ntvalue.setDefaultValue(prop._ntdefault)
ntvalue.setDefault(prop._ntdefault)
tunables[prop] = ntvalue

component._tunables = tunables
Expand Down Expand Up @@ -201,6 +233,10 @@ class MyRobot(magicbot.MagicRobot):
especially if you wish to monitor WPILib objects.

.. versionadded:: 2018.1.0

.. versionchanged:: 2024.1.0
WPILib Struct serializable types are supported when the return type is type hinted.
An ``int`` return type hint now creates an integer topic.
"""
if f is None:
return functools.partial(feedback, key=key)
Expand All @@ -222,10 +258,50 @@ class MyRobot(magicbot.MagicRobot):
return f


_topic_types = {
bool: ntcore.BooleanTopic,
int: ntcore.IntegerTopic,
float: ntcore.DoubleTopic,
str: ntcore.StringTopic,
bytes: ntcore.RawTopic,
}
_array_topic_types = {
bool: ntcore.BooleanArrayTopic,
int: ntcore.IntegerArrayTopic,
float: ntcore.DoubleArrayTopic,
str: ntcore.StringArrayTopic,
}


def _get_topic_type(
return_annotation,
) -> Optional[Callable[[ntcore.Topic], typing.Any]]:
if return_annotation in _topic_types:
return _topic_types[return_annotation]
if hasattr(return_annotation, "WPIStruct"):
return lambda topic: ntcore.StructTopic(topic, return_annotation)

# Check for PEP 484 generic types
origin = getattr(return_annotation, "__origin__", None)
args = typing.get_args(return_annotation)
if origin in (list, tuple, collections.abc.Sequence) and args:
# Ensure tuples are tuple[T, ...] or homogenous
if origin is tuple and not (
(len(args) == 2 and args[1] is Ellipsis) or len(set(args)) == 1
):
return None

inner_type = args[0]
if inner_type in _array_topic_types:
return _array_topic_types[inner_type]
if hasattr(inner_type, "WPIStruct"):
return lambda topic: ntcore.StructArrayTopic(topic, inner_type)


def collect_feedbacks(component, cname: str, prefix: Optional[str] = "components"):
"""
Finds all methods decorated with :func:`feedback` on an object
and returns a list of 2-tuples (method, NetworkTables entry).
and returns a list of 2-tuples (method, NetworkTables entry setter).

.. note:: This isn't useful for normal use.
"""
Expand All @@ -246,7 +322,19 @@ def collect_feedbacks(component, cname: str, prefix: Optional[str] = "components
else:
key = name

entry = nt.getEntry(key)
feedbacks.append((method, entry))
return_annotation = typing.get_type_hints(method).get("return", None)
if return_annotation is not None:
topic_type = _get_topic_type(return_annotation)
else:
topic_type = None

if topic_type is None:
entry = nt.getEntry(key)
setter = entry.setValue
else:
publisher = topic_type(nt.getTopic(key)).publish()
setter = publisher.set

feedbacks.append((method, setter))

return feedbacks
8 changes: 4 additions & 4 deletions magicbot/magicrobot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import hal
import wpilib

from ntcore import NetworkTableInstance, NetworkTableEntry
from ntcore import NetworkTableInstance

# from wpilib.shuffleboard import Shuffleboard

Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(self) -> None:
self.__last_error_report = -10

self._components: List[Tuple[str, Any]] = []
self._feedbacks: List[Tuple[Callable[[], Any], NetworkTableEntry]] = []
self._feedbacks: List[Tuple[Callable[[], Any], Callable[[Any], Any]]] = []
self._reset_components: List[Tuple[Dict[str, Any], Any]] = []

self.__done = False
Expand Down Expand Up @@ -720,13 +720,13 @@ def _do_periodics(self) -> None:
"""Run periodic methods which run in every mode."""
watchdog = self.watchdog

for method, entry in self._feedbacks:
for method, setter in self._feedbacks:
try:
value = method()
except:
self.onException()
else:
entry.setValue(value)
setter(value)

watchdog.addEpoch("@magicbot.feedback")

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ zip_safe = False
include_package_data = True
packages = find:
install_requires =
wpilib>=2024.1.1.0,<2025
wpilib>=2024.3.2.1,<2025
setup_requires =
setuptools_scm > 6
python_requires = >=3.8
Expand Down
105 changes: 105 additions & 0 deletions tests/test_magicbot_feedback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import Sequence, Tuple

import ntcore
from wpimath import geometry

import magicbot


class BasicComponent:
@magicbot.feedback
def get_number(self):
return 0

@magicbot.feedback
def get_ints(self):
return (0,)

@magicbot.feedback
def get_floats(self):
return (0.0, 0)

def execute(self):
pass


class TypeHintedComponent:
@magicbot.feedback
def get_rotation(self) -> geometry.Rotation2d:
return geometry.Rotation2d()

@magicbot.feedback
def get_rotation_array(self) -> Sequence[geometry.Rotation2d]:
return [geometry.Rotation2d()]

@magicbot.feedback
def get_rotation_2_tuple(self) -> Tuple[geometry.Rotation2d, geometry.Rotation2d]:
return (geometry.Rotation2d(), geometry.Rotation2d())

@magicbot.feedback
def get_int(self) -> int:
return 0

@magicbot.feedback
def get_float(self) -> float:
return 0.5

@magicbot.feedback
def get_ints(self) -> Sequence[int]:
return (0,)

@magicbot.feedback
def get_empty_strings(self) -> Sequence[str]:
return ()

def execute(self):
pass


class Robot(magicbot.MagicRobot):
basic: BasicComponent
type_hinted: TypeHintedComponent

def createObjects(self):
pass


def test_feedbacks_with_type_hints():
robot = Robot()
robot.robotInit()
nt = ntcore.NetworkTableInstance.getDefault().getTable("components")

robot._do_periodics()

for name, type_str, value in (
("basic/number", "double", 0.0),
("basic/ints", "int[]", [0]),
("basic/floats", "double[]", [0.0, 0.0]),
("type_hinted/int", "int", 0),
("type_hinted/float", "double", 0.5),
("type_hinted/ints", "int[]", [0]),
("type_hinted/empty_strings", "string[]", []),
):
topic = nt.getTopic(name)
assert topic.getTypeString() == type_str
assert topic.genericSubscribe().get().value() == value

for name, value in [
("type_hinted/rotation", geometry.Rotation2d()),
]:
struct_type = type(value)
assert nt.getTopic(name).getTypeString() == f"struct:{struct_type.__name__}"
topic = nt.getStructTopic(name, struct_type)
assert topic.subscribe(None).get() == value

for name, struct_type, value in (
("type_hinted/rotation_array", geometry.Rotation2d, [geometry.Rotation2d()]),
(
"type_hinted/rotation_2_tuple",
geometry.Rotation2d,
[geometry.Rotation2d(), geometry.Rotation2d()],
),
):
assert nt.getTopic(name).getTypeString() == f"struct:{struct_type.__name__}[]"
topic = nt.getStructArrayTopic(name, struct_type)
assert topic.subscribe([]).get() == value
Loading
Loading