diff --git a/.pylintrc b/.pylintrc
index de98f856..8ff70cfb 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -197,24 +197,42 @@ good-names=a,
b,
c,
d,
+ f,
i,
j,
k,
+ m,
+ M,
+ n,
+ p,
+ ps,
x,
+ x0,
+ x1,
+ X,
y,
+ y0,
+ y1,
z,
u,
+ us,
v,
+ vs,
w,
h,
r,
rc,
+ S,
+ S_inv,
+ t,
ax,
ex,
hz,
kw,
ns,
Run,
+ train_X,
+ test_X,
_
# Good variable names regexes, separated by a comma. If names match any regex,
diff --git a/ada_feeding_msgs/CMakeLists.txt b/ada_feeding_msgs/CMakeLists.txt
index 8f8b6bd4..9447718d 100644
--- a/ada_feeding_msgs/CMakeLists.txt
+++ b/ada_feeding_msgs/CMakeLists.txt
@@ -19,6 +19,7 @@ find_package(rosidl_default_generators REQUIRED)
rosidl_generate_interfaces(${PROJECT_NAME}
"msg/AcquisitionSchema.msg"
"msg/FaceDetection.msg"
+ "msg/FoodOnForkDetection.msg"
"msg/Mask.msg"
"action/AcquireFood.action"
diff --git a/ada_feeding_msgs/msg/FoodOnForkDetection.msg b/ada_feeding_msgs/msg/FoodOnForkDetection.msg
new file mode 100644
index 00000000..fd91d003
--- /dev/null
+++ b/ada_feeding_msgs/msg/FoodOnForkDetection.msg
@@ -0,0 +1,18 @@
+# A message with the results of food on fork detection on a single frame.
+
+# The header for the image the detection corresponds to
+std_msgs/Header header
+
+# The status of the food-on-fork detector.
+int32 status
+int32 SUCCESS=1
+int32 ERROR_TOO_FEW_POINTS=-1
+int32 ERROR_NO_TRANSFORM=-2
+int32 UNKNOWN_ERROR=-99
+
+# A probability in [0,1] that indicates the likelihood that there is food on the
+# fork in the image. Only relevant if status == FoodOnForkDetection.SUCCESS
+float64 probability
+
+# Contains more details of the result, including any error messages that were encountered
+string message
diff --git a/ada_feeding_perception/README.md b/ada_feeding_perception/README.md
index ad7e5aaf..6dcd37a1 100644
--- a/ada_feeding_perception/README.md
+++ b/ada_feeding_perception/README.md
@@ -91,3 +91,35 @@ Launch the web app along with all the other nodes (real or dummy) as documented
- `offline.images` (list of strings, required): The paths, relative to `install/ada_feeding_perception/share/ada_feeding_perception`, to the images to test.
- `offline.point_xs` (list of ints, required): The x-coordinates of the seed points. Must be the same length as `offline.images`.
- `offline.point_ys` (list of ints, required): The y-coordinates of the seed points. Must be the same length as `offline.images`.
+
+## Food-on-Fork Detection
+
+Our eye-in-hand Food-on-Fork Detection node and training/testing infrastructure were designed to make it easy to swap in and compare other food-on-fork detectors. Below are instructions for doing so.
+
+1. **Developing a new food-on-fork detector**: Create a subclass of `FoodOnForkDetector` that implements all of its abstract methods (see the sketch at the end of this section). Note that, as of now, a model does not have access to a real-time TF buffer at test time; hence, **all transforms that the model relies on must be static**.
+2. **Gather the dataset**: Because this node uses the eye-in-hand camera, it is sensitive to the relative pose between the camera and the fork. If you are using PRL's robot, [the dataset collected in early 2024](https://drive.google.com/drive/folders/1hNciBOmuHKd67Pw6oAvj_iN_rY1M8ZV0?usp=drive_link) may be sufficient. Otherwise, you should collect your own dataset:
+ 1. The dataset should consist of a series of ROS2 bags, each recording the following: (a) the aligned depth to color image topic; (b) the color image topic; (c) the camera info topic (we assume it is the same for both); and (d) the TF topic(s).
+ 2. We recorded three types of bags: (a) bags where the robot was going through the motions of feeding without food on the fork and without the fork nearing a person or plate; (b) the same as above but with food on the fork; and (c) bags where the robot was acquiring and feeding a bite to someone. We used the first two types of bags for training, and the third type of bag for evaluation.
+ 3. All ROS2 bags should be in the same directory, with a file `bags_metadata.csv` at the top-level of that directory.
+ 4. `bags_metadata.csv` contains the following columns: `rosbag_name` (str), `time_from_start` (float), `food_on_fork` (0/1), `arm_moving` (0/1). The file only needs rows for the timestamps at which `food_on_fork` and/or `arm_moving` change; for intermediate timestamps, the values are assumed to stay the same (see the hypothetical example at the end of this step).
+ 5. To generate `bags_metadata.csv`, we recommend launching RVIZ, adding your depth and/or RGB image topic, and playing back the bag. e.g.,
+ 1. `ros2 run rviz2 rviz2 --ros-args -p use_sim_time:=true`
+ 2. `ros2 bag play 2024_03_01_two_bites_3 --clock`
+ 3. Pause and resume rosbag playback when food goes on/off the fork and when the arm starts/stops moving, and populate `bags_metadata.csv` accordingly (the elapsed time since the start of the bag is visible at the bottom of RVIZ2).
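+ For reference, a `bags_metadata.csv` might look like the following (the bag names, timestamps, and labels below are purely illustrative):
+ ```
+ rosbag_name,time_from_start,food_on_fork,arm_moving
+ 2024_03_01_no_fof,0.0,0,0
+ 2024_03_01_no_fof,2.5,0,1
+ 2024_03_01_fof_cantaloupe_1,0.0,1,0
+ 2024_03_01_fof_cantaloupe_1,3.2,1,1
+ 2024_03_01_fof_cantaloupe_1,20.7,1,0
+ ```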
+3. **Train/test the model on offline data**: We provide a flexible Python script, `food_on_fork_train_test.py`, to train, test, and/or compare one or more food-on-fork models. To use it, first ensure you have built and sourced your workspace and that you are in the directory that contains the script (e.g., `cd ~/colcon_ws/src/ada_feeding/ada_feeding_perception/ada_feeding_perception`). To enable flexible use, the script has **many** command-line arguments; we recommend reading their descriptions with `python3 food_on_fork_train_test.py -h`. For reference, we include the command we used to train our model below:
+ ```
+ python3 food_on_fork_train_test.py --model-classes '{"distance_no_fof_detector_with_filters": "ada_feeding_perception.food_on_fork_detectors.FoodOnForkDistanceToNoFOFDetector"}' --model-kwargs '{"distance_no_fof_detector_with_filters": {"camera_matrix": [614.5933227539062, 0.0, 312.1358947753906, 0.0, 614.6914672851562, 223.70831298828125, 0.0, 0.0, 1.0], "min_distance": 0.001}}' --lower-thresh 0.25 --upper-thresh 0.75 --train-set-size 0.5 --crop-top-left 344 272 --crop-bottom-right 408 336 --depth-min-mm 310 --depth-max-mm 340 --rosbags-select 2024_03_01_no_fof 2024_03_01_no_fof_1 2024_03_01_no_fof_2 2024_03_01_no_fof_3 2024_03_01_no_fof_4 2024_03_01_fof_cantaloupe_1 2024_03_01_fof_cantaloupe_2 2024_03_01_fof_cantaloupe_3 2024_03_01_fof_strawberry_1 2024_03_01_fof_strawberry_2 2024_03_01_fof_strawberry_3 2024_02_29_no_fof 2024_02_29_fof_cantaloupe 2024_02_29_fof_strawberry --seed 42 --temporal-window-size 5 --spatial-num-pixels 10
+ ```
+Note that we trained our model on data where the fork either had food on it or did not for the entire bag, and where the fork never came near other objects (e.g., the plate or the user's mouth). (Also note that not all of the above ROS2 bags are necessary; we have trained accurate detectors with half of them.) We then evaluated the model offline on bags of actual feeding data:
+ ```
+ python3 food_on_fork_train_test.py --model-classes '{"distance_no_fof_detector_with_filters": "ada_feeding_perception.food_on_fork_detectors.FoodOnForkDistanceToNoFOFDetector"}' --model-kwargs '{"distance_no_fof_detector_with_filters": {"camera_matrix": [614.5933227539062, 0.0, 312.1358947753906, 0.0, 614.6914672851562, 223.70831298828125, 0.0, 0.0, 1.0], "min_distance": 0.001}}' --lower-thresh 0.25 --upper-thresh 0.75 --train-set-size 0.5 --crop-top-left 308 248 --crop-bottom-right 436 332 --depth-min-mm 310 --depth-max-mm 340 --rosbags-select 2024_03_01_two_bites 2024_03_01_two_bites_2 2024_03_01_two_bites_3 2024_02_29_two_bites --seed 42 --temporal-window-size 5 --spatial-num-pixels 10 --no-train
+ ```
+4. **Test the model on online data**: First, copy the parameters you used when training your model, as well as the filename of the saved model, to `config/food_on_fork_detection.yaml`. Re-build and source your workspace.
+ 1. **Live Robot**:
+ 1. Launch the robot as usual; the `ada_feeding_perception` launch file will launch food-on-fork detection.
+ 2. Toggle food-on-fork detection on: `ros2 service call /toggle_food_on_fork_detection std_srvs/srv/SetBool "{data: true}"`
+ 3. Echo the output of food-on-fork detection: `ros2 topic echo /food_on_fork_detection`
+ 2. **ROS2 bag data**:
+ 1. Launch perception: `ros2 launch ada_feeding_perception ada_feeding_perception.launch.py`
+ 2. Toggle food-on-fork detection on and echo the output of food-on-fork detection, as documented above.
+ 3. Launch RVIZ and play back a ROS2 bag, as documented above.
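+
+To make step 1 concrete, below is a minimal sketch of a custom detector. The class name, the `depth_thresh_mm` argument, and the depth-threshold heuristic are hypothetical and for illustration only; see `FoodOnForkDummyDetector` and `FoodOnForkDistanceToNoFOFDetector` in `food_on_fork_detectors.py` for complete implementations of the `FoodOnForkDetector` interface.
+```
+# A minimal sketch of a custom food-on-fork detector. All logic below is illustrative.
+from typing import Optional, Tuple
+
+import numpy as np
+import numpy.typing as npt
+from overrides import override
+
+from ada_feeding_msgs.msg import FoodOnForkDetection
+from ada_feeding_perception.food_on_fork_detectors import FoodOnForkDetector
+
+
+class MyFoodOnForkDetector(FoodOnForkDetector):
+    def __init__(self, depth_thresh_mm: int = 330, verbose: bool = False) -> None:
+        super().__init__(verbose)
+        self.depth_thresh_mm = depth_thresh_mm
+
+    @override
+    def fit(
+        self,
+        X: npt.NDArray,
+        y: npt.NDArray[int],
+        t: npt.NDArray[float],
+        viz_save_dir: Optional[str] = None,
+    ) -> None:
+        pass  # This toy detector has no learned parameters.
+
+    @override
+    def save(self, path: str) -> str:
+        return ""  # Nothing to save.
+
+    @override
+    def load(self, path: str) -> None:
+        pass  # Nothing to load.
+
+    @override
+    def predict_proba(
+        self,
+        X: npt.NDArray,
+        t: npt.NDArray[float],
+    ) -> Tuple[npt.NDArray[float], npt.NDArray[int]]:
+        # Toy heuristic: the fraction of pixels (already cropped and depth-filtered
+        # by the node) that are nonzero and closer than the threshold, per image.
+        probas = np.array(
+            [np.mean((img > 0) & (img < self.depth_thresh_mm)) for img in X]
+        )
+        statuses = np.full(X.shape[0], FoodOnForkDetection.SUCCESS)
+        return probas, statuses
+```
+Once implemented, point the `model_class` parameter in `config/food_on_fork_detection.yaml` (or the `--model-classes` argument of `food_on_fork_train_test.py`) at your new class.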
diff --git a/ada_feeding_perception/ada_feeding_perception/depth_post_processors.py b/ada_feeding_perception/ada_feeding_perception/depth_post_processors.py
index 82ee9f1c..1a68c716 100644
--- a/ada_feeding_perception/ada_feeding_perception/depth_post_processors.py
+++ b/ada_feeding_perception/ada_feeding_perception/depth_post_processors.py
@@ -7,11 +7,13 @@
from typing import Callable
# Third-party imports
+from builtin_interfaces.msg import Time
import cv2 as cv
from cv_bridge import CvBridge
import numpy as np
import numpy.typing as npt
from sensor_msgs.msg import Image
+from std_msgs.msg import Header
def create_mask_post_processor(
@@ -58,7 +60,10 @@ def mask_post_processor(msg: Image) -> Image:
# Get the new img message
masked_msg = bridge.cv2_to_imgmsg(masked_img)
- masked_msg.header = msg.header
+ masked_msg.header = Header(
+ stamp=Time(sec=msg.header.stamp.sec, nanosec=msg.header.stamp.nanosec),
+ frame_id=msg.header.frame_id,
+ )
return masked_msg
@@ -124,7 +129,10 @@ def temporal_post_processor(msg: Image) -> Image:
# Get the new img message
masked_msg = bridge.cv2_to_imgmsg(masked_img)
- masked_msg.header = msg.header
+ masked_msg.header = Header(
+ stamp=Time(sec=msg.header.stamp.sec, nanosec=msg.header.stamp.nanosec),
+ frame_id=msg.header.frame_id,
+ )
return masked_msg
@@ -176,7 +184,10 @@ def spatial_post_processor(msg: Image) -> Image:
# Get the new img message
masked_msg = bridge.cv2_to_imgmsg(masked_img)
- masked_msg.header = msg.header
+ masked_msg.header = Header(
+ stamp=Time(sec=msg.header.stamp.sec, nanosec=msg.header.stamp.nanosec),
+ frame_id=msg.header.frame_id,
+ )
return masked_msg
@@ -234,7 +245,10 @@ def threshold_post_processor(msg: Image) -> Image:
# Get the new img message
masked_msg = bridge.cv2_to_imgmsg(masked_img)
- masked_msg.header = msg.header
+ masked_msg.header = Header(
+ stamp=Time(sec=msg.header.stamp.sec, nanosec=msg.header.stamp.nanosec),
+ frame_id=msg.header.frame_id,
+ )
return masked_msg
diff --git a/ada_feeding_perception/ada_feeding_perception/face_detection.py b/ada_feeding_perception/ada_feeding_perception/face_detection.py
index 5f822457..86f3f7f8 100755
--- a/ada_feeding_perception/ada_feeding_perception/face_detection.py
+++ b/ada_feeding_perception/ada_feeding_perception/face_detection.py
@@ -56,6 +56,8 @@ class FaceDetectionNode(Node):
let the client decide which face to use.
"""
+ # pylint: disable=duplicate-code
+ # Much of the logic of this node mirrors FoodOnForkDetection. This is fine.
# pylint: disable=too-many-instance-attributes
# Needed for multiple model loads, publisher, subscribers, and shared variables
def __init__(
@@ -305,10 +307,6 @@ def toggle_face_detection_callback(
the face detection on or off depending on the request.
"""
- # pylint: disable=duplicate-code
- # We follow similar logic in any service to toggle a node
- # (e.g., face detection)
-
self.get_logger().info(f"Incoming service request. data: {request.data}")
response.success = False
response.message = f"Failed to set is_on to {request.data}"
@@ -563,6 +561,7 @@ def get_mouth_depth(
f"Corresponding RGB image message received at {rgb_msg.header.stamp}. "
f"Time difference: {min_time_diff} seconds."
)
+ # TODO: This should use the ros_msg_to_cv2_image helper function
image_depth = self.bridge.imgmsg_to_cv2(
closest_depth_msg,
desired_encoding="passthrough",
@@ -651,6 +650,7 @@ def run(self) -> None:
continue
# Detect the largest face in the RGB image
+ # TODO: This should use the ros_msg_to_cv2_image helper function
image_bgr = cv2.imdecode(
np.frombuffer(rgb_msg.data, np.uint8), cv2.IMREAD_COLOR
)
diff --git a/ada_feeding_perception/ada_feeding_perception/food_on_fork_detection.py b/ada_feeding_perception/ada_feeding_perception/food_on_fork_detection.py
new file mode 100644
index 00000000..34e14c96
--- /dev/null
+++ b/ada_feeding_perception/ada_feeding_perception/food_on_fork_detection.py
@@ -0,0 +1,716 @@
+"""
+This module contains a ROS2 node that: (a) takes in parameters specifying a
+FoodOnForkDetector class to use and kwargs for the class's constructor; (b) exposes a
+ROS2 service to
+toggle the perception algorithm on and off; and (c) when the perception algorithm is
+on, subscribes to the depth image topic and publishes the confidence that there is food
+on the fork.
+"""
+# Standard imports
+import collections
+import os
+import threading
+from typing import Any, Dict, Tuple
+
+# Third-party imports
+from cv_bridge import CvBridge
+import cv2
+import numpy as np
+import numpy.typing as npt
+from rcl_interfaces.msg import ParameterDescriptor, ParameterType
+import rclpy
+from rclpy.callback_groups import MutuallyExclusiveCallbackGroup
+from rclpy.executors import MultiThreadedExecutor
+from rclpy.node import Node
+from sensor_msgs.msg import CameraInfo, CompressedImage, Image
+from std_srvs.srv import SetBool
+from tf2_ros.buffer import Buffer
+from tf2_ros.transform_listener import TransformListener
+
+# Local imports
+from ada_feeding.helpers import import_from_string
+from ada_feeding_msgs.msg import FoodOnForkDetection
+from ada_feeding_perception.food_on_fork_detectors import FoodOnForkDetector
+from ada_feeding_perception.helpers import (
+ cv2_image_to_ros_msg,
+ get_img_msg_type,
+ ros_msg_to_cv2_image,
+)
+from .depth_post_processors import (
+ create_spatial_post_processor,
+ create_temporal_post_processor,
+)
+
+
+class FoodOnForkDetectionNode(Node):
+ """
+ A ROS2 node that takes in parameters specifying a FoodOnForkDetector class to use and
+ kwargs for the class's constructor, exposes a ROS2 service to toggle the perception
+ algorithm on and off, and when the perception algorithm is on, subscribes to the
+ depth image topic and publishes the confidence that there is food on the fork.
+ """
+
+ # pylint: disable=duplicate-code
+ # Much of the logic of this node mirrors FaceDetection. This is fine.
+ # pylint: disable=too-many-instance-attributes
+ # Needed for multiple publishers/subscribers, model parameters, etc.
+ def __init__(self):
+ """
+ Initializes the FoodOnForkDetection.
+ """
+ super().__init__("food_on_fork_detection")
+
+ # Load the parameters
+ (
+ model_class,
+ model_path,
+ model_dir,
+ model_kwargs,
+ self.rate_hz,
+ self.crop_top_left,
+ self.crop_bottom_right,
+ self.depth_min_mm,
+ self.depth_max_mm,
+ temporal_window_size,
+ spatial_num_pixels,
+ self.viz,
+ self.viz_upper_thresh,
+ self.viz_lower_thresh,
+ rgb_image_buffer,
+ ) = self.read_params()
+
+ # Create the post-processors
+ self.cv_bridge = CvBridge()
+ self.post_processors = [
+ create_temporal_post_processor(temporal_window_size, self.cv_bridge),
+ create_spatial_post_processor(spatial_num_pixels, self.cv_bridge),
+ ]
+
+ # Construct the FoodOnForkDetector model
+ food_on_fork_class = import_from_string(model_class)
+ assert issubclass(
+ food_on_fork_class, FoodOnForkDetector
+ ), f"Model {model_class} must subclass FoodOnForkDetector"
+ self.model = food_on_fork_class(**model_kwargs)
+ self.model.crop_top_left = self.crop_top_left
+ self.model.crop_bottom_right = self.crop_bottom_right
+ if len(model_path) > 0:
+ self.model.load(os.path.join(model_dir, model_path))
+
+ # Create the TF buffer, in case the perception algorithm needs it
+ self.tf_buffer = Buffer()
+ self.tf_listener = TransformListener(self.tf_buffer, self)
+ self.model.tf_buffer = self.tf_buffer
+
+ # Create the service to toggle the perception algorithm on and off
+ self.is_on = False
+ self.is_on_lock = threading.Lock()
+ self.srv = self.create_service(
+ SetBool,
+ "~/toggle_food_on_fork_detection",
+ self.toggle_food_on_fork_detection,
+ callback_group=MutuallyExclusiveCallbackGroup(),
+ )
+
+ # Create the publisher
+ self.pub = self.create_publisher(
+ FoodOnForkDetection, "~/food_on_fork_detection", 1
+ )
+
+ # Create the CameraInfo subscribers
+ self.camera_info_sub = self.create_subscription(
+ CameraInfo,
+ "~/camera_info",
+ self.camera_info_callback,
+ 1,
+ callback_group=MutuallyExclusiveCallbackGroup(),
+ )
+
+ # Create the depth image subscriber
+ self.depth_img = None
+ self.depth_img_lock = threading.Lock()
+ aligned_depth_topic = "~/aligned_depth"
+ try:
+ aligned_depth_type = get_img_msg_type(aligned_depth_topic, self)
+ except ValueError as err:
+ self.get_logger().error(
+ f"Error getting type of depth image topic. Defaulting to Image. {err}"
+ )
+ aligned_depth_type = Image
+ # Subscribe to the depth image
+ self.depth_subscription = self.create_subscription(
+ aligned_depth_type,
+ aligned_depth_topic,
+ self.depth_callback,
+ 1,
+ callback_group=MutuallyExclusiveCallbackGroup(),
+ )
+
+ # If the visualization flag is set, create a subscriber to the RGB image
+ # and publisher for the RGB visualization
+ if self.viz:
+ self.rgb_pub = self.create_publisher(
+ Image, "~/food_on_fork_detection_img", 1
+ )
+ self.img_buffer = collections.deque(maxlen=rgb_image_buffer)
+ self.img_buffer_lock = threading.Lock()
+ image_topic = "~/image"
+ try:
+ image_type = get_img_msg_type(image_topic, self)
+ except ValueError as err:
+ self.get_logger().error(
+ f"Error getting type of image topic. Defaulting to CompressedImage. {err}"
+ )
+ image_type = CompressedImage
+ self.img_subscription = self.create_subscription(
+ image_type,
+ image_topic,
+ self.camera_callback,
+ 1,
+ callback_group=MutuallyExclusiveCallbackGroup(),
+ )
+
+ def read_params(
+ self,
+ ) -> Tuple[
+ str,
+ str,
+ str,
+ Dict[str, Any],
+ float,
+ Tuple[int, int],
+ Tuple[int, int],
+ int,
+ int,
+ int,
+ int,
+ bool,
+ float,
+ float,
+ int,
+ ]:
+ """
+ Reads the parameters for the FoodOnForkDetection.
+
+ Returns
+ -------
+ model_class: The FoodOnForkDetector class to use. Must be a subclass of FoodOnForkDetector.
+ model_path: The path to the model file. This must be relative to the model_dir
+ parameter. Ignored if the empty string.
+ model_dir: The directory to load the model from.
+ model_kwargs: The keyword arguments to pass to the FoodOnForkDetector class's constructor.
+ rate_hz: The rate (Hz) at which to publish.
+ crop_top_left: The top left corner of the crop box.
+ crop_bottom_right: The bottom right corner of the crop box.
+ depth_min_mm: The minimum depth (mm) to consider.
+ depth_max_mm: The maximum depth (mm) to consider.
+ temporal_window_size: The size of the temporal window for post-processing.
+ Disabled by default.
+ spatial_num_pixels: The number of pixels for the spatial post-processor.
+ Disabled by default.
+ viz: Whether to publish a visualization of the result as an RGB image.
+ viz_upper_thresh: The upper threshold for declaring FoF in the viz.
+ viz_lower_thresh: The lower threshold for declaring FoF in the viz.
+ rgb_image_buffer: The number of RGB images to store at a time for visualization.
+ """
+ # pylint: disable=too-many-locals
+ # There are many parameters to load.
+
+ # Read the model_class
+ model_class = self.declare_parameter(
+ "model_class",
+ descriptor=ParameterDescriptor(
+ name="model_class",
+ type=ParameterType.PARAMETER_STRING,
+ description=(
+ "The FoodOnFork class to use. Must be a subclass of FoodOnFork."
+ ),
+ read_only=True,
+ ),
+ )
+ model_class = model_class.value
+
+ # Read the model_path
+ model_path = self.declare_parameter(
+ "model_path",
+ descriptor=ParameterDescriptor(
+ name="model_path",
+ type=ParameterType.PARAMETER_STRING,
+ description=(
+ "The path to the model file. This must be relative to the "
+ "model_dir parameter. Ignored if the empty string."
+ ),
+ read_only=True,
+ ),
+ )
+ model_path = model_path.value
+
+ # Read the model_dir
+ model_dir = self.declare_parameter(
+ "model_dir",
+ descriptor=ParameterDescriptor(
+ name="model_dir",
+ type=ParameterType.PARAMETER_STRING,
+ description=("The directory to load the model from."),
+ read_only=True,
+ ),
+ )
+ model_dir = model_dir.value
+
+ # Read the model_kwargs
+ model_kwargs = {}
+ model_kws = self.declare_parameter(
+ "model_kws",
+ descriptor=ParameterDescriptor(
+ name="model_kws",
+ type=ParameterType.PARAMETER_STRING_ARRAY,
+ description=(
+ "The keywords to pass to the FoodOnFork class's constructor."
+ ),
+ read_only=True,
+ ),
+ )
+ for kw in model_kws.value:
+ full_name = f"model_kwargs.{kw}"
+ arg = self.declare_parameter(
+ full_name,
+ descriptor=ParameterDescriptor(
+ name=kw,
+ description="Custom keyword argument for the model.",
+ dynamic_typing=True,
+ read_only=True,
+ ),
+ )
+ if isinstance(arg, collections.abc.Sequence):
+ arg = list(arg.value)
+ else:
+ arg = arg.value
+ model_kwargs[kw] = arg
+
+ # Get the rate at which to operate
+ rate_hz = self.declare_parameter(
+ "rate_hz",
+ 10.0,
+ descriptor=ParameterDescriptor(
+ name="rate_hz",
+ type=ParameterType.PARAMETER_DOUBLE,
+ description="The rate (Hz) at which to publish.",
+ read_only=True,
+ ),
+ )
+ rate_hz = rate_hz.value
+
+ # Get the crop box
+ crop_top_left = self.declare_parameter(
+ "crop_top_left",
+ (0, 0),
+ descriptor=ParameterDescriptor(
+ name="crop_top_left",
+ type=ParameterType.PARAMETER_INTEGER_ARRAY,
+ description="The top left corner of the crop box.",
+ read_only=True,
+ ),
+ )
+ crop_top_left = crop_top_left.value
+ crop_bottom_right = self.declare_parameter(
+ "crop_bottom_right",
+ (0, 0),
+ descriptor=ParameterDescriptor(
+ name="crop_bottom_right",
+ type=ParameterType.PARAMETER_INTEGER_ARRAY,
+ description="The bottom right corner of the crop box.",
+ read_only=True,
+ ),
+ )
+ crop_bottom_right = crop_bottom_right.value
+
+ # Get the depth range
+ depth_min_mm = self.declare_parameter(
+ "depth_min_mm",
+ 0,
+ descriptor=ParameterDescriptor(
+ name="depth_min_mm",
+ type=ParameterType.PARAMETER_INTEGER,
+ description="The minimum depth (mm) to consider.",
+ read_only=True,
+ ),
+ )
+ depth_min_mm = depth_min_mm.value
+ depth_max_mm = self.declare_parameter(
+ "depth_max_mm",
+ 20000,
+ descriptor=ParameterDescriptor(
+ name="depth_max_mm",
+ type=ParameterType.PARAMETER_INTEGER,
+ description="The maximum depth (mm) to consider.",
+ read_only=True,
+ ),
+ )
+ depth_max_mm = depth_max_mm.value
+
+ # Configure the post-processors
+ temporal_window_size = self.declare_parameter(
+ "temporal_window_size",
+ 1,
+ descriptor=ParameterDescriptor(
+ name="temporal_window_size",
+ type=ParameterType.PARAMETER_INTEGER,
+ description="The size of the temporal window for post-processing. Disabled by default.",
+ read_only=True,
+ ),
+ )
+ temporal_window_size = temporal_window_size.value
+ spatial_num_pixels = self.declare_parameter(
+ "spatial_num_pixels",
+ 1,
+ descriptor=ParameterDescriptor(
+ name="spatial_num_pixels",
+ type=ParameterType.PARAMETER_INTEGER,
+ description="The number of pixels for the spatial post-processor. Disabled by default.",
+ read_only=True,
+ ),
+ )
+ spatial_num_pixels = spatial_num_pixels.value
+
+ # Get the visualization parameters
+ viz = self.declare_parameter(
+ "viz",
+ False,
+ descriptor=ParameterDescriptor(
+ name="viz",
+ type=ParameterType.PARAMETER_BOOL,
+ description="Whether to publish a visualization of the result as an RGB image.",
+ read_only=True,
+ ),
+ )
+ viz = viz.value
+ viz_upper_thresh = self.declare_parameter(
+ "viz_upper_thresh",
+ 0.5,
+ descriptor=ParameterDescriptor(
+ name="viz_upper_thresh",
+ type=ParameterType.PARAMETER_DOUBLE,
+ description="The upper threshold for declaring FoF in the viz.",
+ read_only=True,
+ ),
+ )
+ viz_upper_thresh = viz_upper_thresh.value
+ viz_lower_thresh = self.declare_parameter(
+ "viz_lower_thresh",
+ 0.5,
+ descriptor=ParameterDescriptor(
+ name="viz_lower_thresh",
+ type=ParameterType.PARAMETER_DOUBLE,
+ description="The lower threshold for declaring FoF in the viz.",
+ read_only=True,
+ ),
+ )
+ viz_lower_thresh = viz_lower_thresh.value
+ rgb_image_buffer = self.declare_parameter(
+ "rgb_image_buffer",
+ 30,
+ descriptor=ParameterDescriptor(
+ name="rgb_image_buffer",
+ type=ParameterType.PARAMETER_INTEGER,
+ description=(
+ "The number of RGB images to store at a time for visualization. Default: 30"
+ ),
+ read_only=True,
+ ),
+ )
+ rgb_image_buffer = rgb_image_buffer.value
+
+ return (
+ model_class,
+ model_path,
+ model_dir,
+ model_kwargs,
+ rate_hz,
+ crop_top_left,
+ crop_bottom_right,
+ depth_min_mm,
+ depth_max_mm,
+ temporal_window_size,
+ spatial_num_pixels,
+ viz,
+ viz_upper_thresh,
+ viz_lower_thresh,
+ rgb_image_buffer,
+ )
+
+ def toggle_food_on_fork_detection(
+ self, request: SetBool.Request, response: SetBool.Response
+ ) -> SetBool.Response:
+ """
+ Toggles the perception algorithm on and off.
+
+ Parameters
+ ----------
+ request: The request to toggle the perception algorithm on and off.
+ response: The response to toggle the perception algorithm on and off.
+
+ Returns
+ -------
+ response: The response to toggle the perception algorithm on and off.
+ """
+
+ self.get_logger().info(f"Incoming service request. data: {request.data}")
+ response.success = False
+ response.message = f"Failed to set is_on to {request.data}"
+ with self.is_on_lock:
+ self.is_on = request.data
+ response.success = True
+ response.message = f"Successfully set is_on to {request.data}"
+ return response
+
+ def camera_info_callback(self, msg: CameraInfo) -> None:
+ """
+ Callback for the camera info. Note that we assume CameraInfo never
+ changes, and therefore destroy the subscriber after the first message.
+
+ Parameters
+ ----------
+ msg: The camera info message.
+ """
+ if self.model.camera_info is None:
+ self.model.camera_info = msg
+ self.destroy_subscription(self.camera_info_sub)
+
+ def depth_callback(self, msg: Image) -> None:
+ """
+ Callback for the depth image.
+
+ Parameters
+ ----------
+ msg: The depth image message.
+ """
+ for post_processor in self.post_processors:
+ msg = post_processor(msg)
+ with self.depth_img_lock:
+ self.depth_img = msg
+
+ def camera_callback(self, msg: Image) -> None:
+ """
+ Callback for the camera image.
+
+ Parameters
+ ----------
+ msg: The camera image message.
+ """
+ with self.img_buffer_lock:
+ self.img_buffer.append(msg)
+
+ def visualize_result(
+ self, result: FoodOnForkDetection, t: npt.NDArray, debug: bool = True
+ ) -> None:
+ """
+ Annotates the nearest RGB image message with the result and publishes it.
+
+ Parameters
+ ----------
+ result: The result of the food on fork detection.
+ t: The transform(s) used in the detection. Size (N, 4, 4) where N is the
+ number of transforms.
+ debug: Whether to overlay additional debug information on the image.
+ """
+ # Get the RGB image with timestamp closest to the depth image
+ with self.img_buffer_lock:
+ img_msg = None
+ # At the end of this for loop, img_msg will be the most
+ # recent image that is older than the depth image, or the
+ # oldest image if there are no images older than the depth
+ # image.
+ for i, img_msg in enumerate(self.img_buffer):
+ img_msg_stamp = (
+ img_msg.header.stamp.sec + img_msg.header.stamp.nanosec * 1e-9
+ )
+ result_stamp = (
+ result.header.stamp.sec + result.header.stamp.nanosec * 1e-9
+ )
+ if img_msg_stamp > result_stamp:
+ if i > 0:
+ img_msg = self.img_buffer[i - 1]
+ break
+ # If img_msg is None, that means we haven't received an RGB image yet
+ if img_msg is None:
+ return
+
+ # Convert the RGB image to a cv2 image
+ img_cv2 = ros_msg_to_cv2_image(img_msg, self.cv_bridge)
+
+ # Allow the model to overlay additional debug information on the image
+ if debug:
+ img_cv2 = self.model.overlay_debug_info(img_cv2, t)
+
+ # Get the message to write on the image
+ proba = result.probability
+ status = result.status
+ if proba > self.viz_upper_thresh:
+ pred = "Food on Fork"
+ color = (0, 255, 0)
+ elif (
+ proba <= self.viz_lower_thresh
+ or status == FoodOnForkDetection.ERROR_TOO_FEW_POINTS
+ ):
+ pred = "No Food on Fork"
+ color = (0, 0, 255)
+ elif status == FoodOnForkDetection.SUCCESS:
+ pred = "Uncertain (Ask User)"
+ color = (255, 0, 0)
+ else:
+ pred = "Unknown Error"
+ color = (255, 255, 255)
+ msg = f"{pred}: {proba:.2f}"
+
+ # Write the message on the top-left corner of the image
+ cv2.putText(
+ img_cv2,
+ msg,
+ (10, 30),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 1,
+ color,
+ 2,
+ cv2.LINE_AA,
+ )
+
+ # Add a rectangular border around the image in the specified color
+ cv2.rectangle(
+ img_cv2,
+ (0, 0),
+ (img_cv2.shape[1], img_cv2.shape[0]),
+ color,
+ 10,
+ )
+
+ # Also add a rectangle around the crop box
+ cv2.rectangle(
+ img_cv2,
+ self.crop_top_left,
+ self.crop_bottom_right,
+ color,
+ 2,
+ )
+
+ # Publish the image
+ self.rgb_pub.publish(
+ cv2_image_to_ros_msg(img_cv2, compress=False, bridge=self.cv_bridge)
+ )
+
+ def run(self) -> None:
+ """
+ Runs the FoodOnForkDetection.
+ """
+ rate = self.create_rate(self.rate_hz)
+ while rclpy.ok():
+ # Loop at the specified rate
+ rate.sleep()
+
+ # Check if food on fork detection is on
+ with self.is_on_lock:
+ is_on = self.is_on
+ if not is_on:
+ continue
+
+ # Create the FoodOnForkDetection message
+ food_on_fork_detection_msg = FoodOnForkDetection()
+
+ # Get the latest depth image
+ with self.depth_img_lock:
+ depth_img_msg = self.depth_img
+ self.depth_img = None
+ if depth_img_msg is None:
+ continue
+ food_on_fork_detection_msg.header = depth_img_msg.header
+
+ # Convert the depth image to a cv2 image, crop it, and remove depth
+ # values outside the range of interest
+ depth_img_cv2 = ros_msg_to_cv2_image(depth_img_msg, self.cv_bridge)
+ depth_img_cv2 = depth_img_cv2[
+ self.crop_top_left[1] : self.crop_bottom_right[1],
+ self.crop_top_left[0] : self.crop_bottom_right[0],
+ ]
+ depth_img_cv2 = np.where(
+ (depth_img_cv2 >= self.depth_min_mm)
+ & (depth_img_cv2 <= self.depth_max_mm),
+ depth_img_cv2,
+ 0,
+ )
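+ # Add a batch dimension, since the detector's predict_proba expects a batch of images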
+ X = np.expand_dims(depth_img_cv2, axis=0)
+
+ # Get the desired transform(s)
+ transforms = FoodOnForkDetector.get_transforms(
+ self.model.transform_frames,
+ self.tf_buffer,
+ )
+ t = np.expand_dims(transforms, 0)
+
+ # Get the probability that there is food on the fork
+ try:
+ proba, status = self.model.predict_proba(X, t)
+ proba = proba[0]
+ status = int(status[0])
+ food_on_fork_detection_msg.probability = proba
+ food_on_fork_detection_msg.status = status
+ if status == FoodOnForkDetection.SUCCESS:
+ food_on_fork_detection_msg.message = "No errors."
+ elif status == FoodOnForkDetection.ERROR_TOO_FEW_POINTS:
+ food_on_fork_detection_msg.message = (
+ "Error: Too few detected points. This typically means there is "
+ "no food on the fork."
+ )
+ elif status == FoodOnForkDetection.ERROR_NO_TRANSFORM:
+ food_on_fork_detection_msg.message = (
+ "Error: Could not get requested transform(s)."
+ )
+ # pylint: disable=broad-except
+ # This is necessary because we don't know what exceptions the model
+ # might raise.
+ except Exception as err:
+ err_str = f"Error predicting food on fork: {err}"
+ self.get_logger().error(err_str)
+ food_on_fork_detection_msg.probability = np.nan
+ food_on_fork_detection_msg.status = FoodOnForkDetection.UNKNOWN_ERROR
+ food_on_fork_detection_msg.message = err_str
+
+ # Visualize the results
+ if self.viz:
+ self.visualize_result(food_on_fork_detection_msg, t[0])
+
+ # Publish the FoodOnForkDetection message
+ self.pub.publish(food_on_fork_detection_msg)
+
+
+def main(args=None):
+ """
+ Launch the ROS node and spin.
+ """
+ rclpy.init(args=args)
+
+ food_on_fork_detection = FoodOnForkDetectionNode()
+ executor = MultiThreadedExecutor(num_threads=4)
+
+ # Spin in the background since food-on-fork detection will block
+ # the main thread
+ spin_thread = threading.Thread(
+ target=rclpy.spin,
+ args=(food_on_fork_detection,),
+ kwargs={"executor": executor},
+ daemon=True,
+ )
+ spin_thread.start()
+
+ # Run food-on-fork detection
+ try:
+ food_on_fork_detection.run()
+ except KeyboardInterrupt:
+ pass
+
+ # Terminate this node
+ food_on_fork_detection.destroy_node()
+ rclpy.shutdown()
+ # Join the spin thread (so it is spinning in the main thread)
+ spin_thread.join()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/ada_feeding_perception/ada_feeding_perception/food_on_fork_detectors.py b/ada_feeding_perception/ada_feeding_perception/food_on_fork_detectors.py
new file mode 100644
index 00000000..626de1a2
--- /dev/null
+++ b/ada_feeding_perception/ada_feeding_perception/food_on_fork_detectors.py
@@ -0,0 +1,942 @@
+"""
+This file contains an abstract class, FoodOnForkDetector, that takes in a single depth
+image and returns a confidence in [0,1] that there is food on the fork.
+"""
+# Standard imports
+from abc import ABC, abstractmethod
+from enum import Enum
+import os
+import time
+from typing import List, Optional, Tuple
+
+# Third-party imports
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import numpy.typing as npt
+from overrides import override
+import rclpy
+from sensor_msgs.msg import CameraInfo
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import f1_score
+from sklearn.metrics.pairwise import pairwise_distances
+from sklearn.model_selection import train_test_split
+import tf2_ros
+from tf2_ros.buffer import Buffer
+from transforms3d._gohlketransforms import quaternion_matrix
+
+# Local imports
+from ada_feeding_msgs.msg import FoodOnForkDetection
+from ada_feeding_perception.helpers import (
+ depth_img_to_pointcloud,
+ show_3d_scatterplot,
+ show_normalized_depth_img,
+)
+
+
+class FoodOnForkLabel(Enum):
+ """
+ An enumeration of possible labels for food on the fork.
+ """
+
+ NO_FOOD = 0
+ FOOD = 1
+ UNSURE = 2
+
+
+class FoodOnForkDetector(ABC):
+ """
+ An abstract class for any perception algorithm that takes in a single depth
+ image and returns a confidence in [0,1] that there is food on the fork.
+ """
+
+ def __init__(self, verbose: bool = False) -> None:
+ """
+ Initializes the perception algorithm.
+
+ Parameters
+ ----------
+ verbose: Whether to print debug messages.
+ """
+ self.__camera_info = None
+ self.__crop_top_left = (0, 0)
+ self.__crop_bottom_right = (640, 480)
+ self.__seed = int(time.time() * 1000)
+ self.verbose = verbose
+
+ @property
+ def camera_info(self) -> Optional[CameraInfo]:
+ """
+ The camera info for the depth image.
+
+ Returns
+ -------
+ camera_info: The camera info for the depth image, or None if not set.
+ """
+ return self.__camera_info
+
+ @camera_info.setter
+ def camera_info(self, camera_info: CameraInfo) -> None:
+ """
+ Sets the camera info for the depth image.
+
+ Parameters
+ ----------
+ camera_info: The camera info for the depth image.
+ """
+ self.__camera_info = camera_info
+
+ @property
+ def crop_top_left(self) -> Tuple[int, int]:
+ """
+ The top left corner of the region of interest in the depth image.
+
+ Returns
+ -------
+ crop_top_left: The top left corner of the region of interest in the depth
+ image.
+ """
+ return self.__crop_top_left
+
+ @crop_top_left.setter
+ def crop_top_left(self, crop_top_left: Tuple[int, int]) -> None:
+ """
+ Sets the top left corner of the region of interest in the depth image.
+
+ Parameters
+ ----------
+ crop_top_left: The top left corner of the region of interest in the depth
+ image.
+ """
+ self.__crop_top_left = crop_top_left
+
+ @property
+ def crop_bottom_right(self) -> Tuple[int, int]:
+ """
+ The bottom right corner of the region of interest in the depth image.
+
+ Returns
+ -------
+ crop_bottom_right: The bottom right corner of the region of interest in
+ the depth image.
+ """
+ return self.__crop_bottom_right
+
+ @crop_bottom_right.setter
+ def crop_bottom_right(self, crop_bottom_right: Tuple[int, int]) -> None:
+ """
+ Sets the bottom right corner of the region of interest in the depth image.
+
+ Parameters
+ ----------
+ crop_bottom_right: The bottom right corner of the region of interest in
+ the depth image.
+ """
+ self.__crop_bottom_right = crop_bottom_right
+
+ @property
+ def seed(self) -> int:
+ """
+ The random seed to use in the detector.
+
+ Returns
+ -------
+ seed: The random seed to use in the detector.
+ """
+ return self.__seed
+
+ @seed.setter
+ def seed(self, seed: int) -> None:
+ """
+ Sets the random seed to use in the detector.
+
+ Parameters
+ ----------
+ seed: The random seed to use in the detector.
+ """
+ self.__seed = seed
+
+ @property
+ def transform_frames(self) -> List[Tuple[str, str]]:
+ """
+ Gets the parent and child frame for every transform that this classifier
+ wants to use.
+
+ Returns
+ -------
+ frames: A list of (parent_frame_id, child_frame_id) tuples.
+ """
+ return []
+
+ @staticmethod
+ def get_transforms(frames: List[Tuple[str, str]], tf_buffer: Buffer) -> npt.NDArray:
+ """
+ Gets the most recent transforms that are necessary for this classifier.
+ These are then passed into fit, predict_proba, and predict.
+
+ Parameters
+ ----------
+ frames: A list of (parent_frame_id, child_frame_id) tuples to get transforms
+ for. Size: (num_transforms, 2).
+ tf_buffer: The tf buffer that stores the transforms.
+
+ Returns
+ -------
+ transforms: The transforms (homogenous coordinates) that are necessary
+ for this classifier. Size (num_transforms, 4, 4). Note that if the
+ transform is not found, it will be a zero matrix.
+ """
+ transforms = []
+ for parent_frame_id, child_frame_id in frames:
+ try:
+ transform = tf_buffer.lookup_transform(
+ parent_frame_id,
+ child_frame_id,
+ rclpy.time.Time(),
+ )
+ # Convert the transform into a matrix
+ M = quaternion_matrix(
+ [
+ transform.transform.rotation.w,
+ transform.transform.rotation.x,
+ transform.transform.rotation.y,
+ transform.transform.rotation.z,
+ ],
+ )
+ M[:3, 3] = [
+ transform.transform.translation.x,
+ transform.transform.translation.y,
+ transform.transform.translation.z,
+ ]
+ transforms.append(M)
+ except (
+ tf2_ros.LookupException,
+ tf2_ros.ConnectivityException,
+ tf2_ros.ExtrapolationException,
+ ) as err:
+ print(
+ f"Error getting transform from {parent_frame_id} to {child_frame_id}: {err}"
+ )
+ transforms.append(np.zeros((4, 4), dtype=float))
+
+ return np.array(transforms)
+
+ @abstractmethod
+ def fit(
+ self,
+ X: npt.NDArray,
+ y: npt.NDArray[int],
+ t: npt.NDArray[float],
+ viz_save_dir: Optional[str] = None,
+ ) -> None:
+ """
+ Trains the perception algorithm on a dataset of depth images and
+ corresponding labels.
+
+ Parameters
+ ----------
+ X: The depth images to train on. Size (num_images, height, width).
+ y: The labels for the depth images. Size (num_images,). Must be one of the
+ values enumerated in FoodOnForkLabel.
+ t: The transforms (homogenous coordinates) that are necessary for this
+ classifier. Size (num_images, num_transforms, 4, 4). Should be outputted
+ by `get_transforms`.
+ viz_save_dir: The directory to save visualizations to. If None, no
+ visualizations will be saved.
+ """
+
+ @abstractmethod
+ def save(self, path: str) -> str:
+ """
+ Saves the model to a file.
+
+ Parameters
+ ----------
+ path: The path to save the perception algorithm to. This file should not
+ have an extension; this function will add the appropriate extension.
+
+ Returns
+ -------
+ save_path: The path that the model was saved to.
+ """
+
+ @abstractmethod
+ def load(self, path: str) -> None:
+ """
+ Loads the model from a file.
+
+ Parameters
+ ----------
+ path: The path to load the perception algorithm from. If the path does
+ not have an extension, this function will add the appropriate
+ extension.
+ """
+
+ @abstractmethod
+ def predict_proba(
+ self,
+ X: npt.NDArray,
+ t: npt.NDArray[float],
+ ) -> Tuple[npt.NDArray[float], npt.NDArray[int]]:
+ """
+ Predicts the probability that there is food on the fork for a set of
+ depth images.
+
+ Parameters
+ ----------
+ X: The depth images to predict on.
+ t: The transforms (homogenous coordinates) that are necessary for this
+ classifier. Size (num_images, num_transforms, 4, 4). Should be outputted
+ by `get_transforms`.
+
+ Returns
+ -------
+ y: The predicted probabilities that there is food on the fork.
+ statuses: The status of each prediction. Must be one of the const values
+ declared in the FoodOnForkDetection message.
+ """
+
+ def predict(
+ self,
+ X: npt.NDArray,
+ t: npt.NDArray[float],
+ lower_thresh: float,
+ upper_thresh: float,
+ proba: Optional[npt.NDArray] = None,
+ statuses: Optional[npt.NDArray[int]] = None,
+ ) -> Tuple[npt.NDArray[int], npt.NDArray[int]]:
+ """
+ Predicts whether there is food on the fork for a set of depth images.
+
+ Parameters
+ ----------
+ X: The depth images to predict on.
+ t: The transforms (homogenous coordinates) that are necessary for this
+ classifier. Size (num_images, num_transforms, 4, 4). Should be outputted
+ by `get_transforms`.
+ lower_thresh: The lower threshold for food on the fork.
+ upper_thresh: The upper threshold for food on the fork.
+ proba: The predicted probabilities that there is food on the fork. If either
+ proba or statuses is None, this function will call predict_proba to get
+ the proba and statuses.
+ statuses: The status of each prediction. Must be one of the const values
+ declared in the FoodOnForkDetection message. If either proba or statuses
+ is None, this function will call predict_proba to get the proba and
+ statuses.
+
+ Returns
+ -------
+ y: The predicted labels for whether there is food on the fork. Must be one
+ of the values enumerated in FoodOnForkLabel.
+ statuses: The status of each prediction. Must be one of the const values
+ declared in the FoodOnForkDetection message.
+ """
+ # pylint: disable=too-many-arguments
+ # These many are fine.
+ if proba is None or statuses is None:
+ proba, statuses = self.predict_proba(X, t)
+ return (
+ np.where(
+ proba < lower_thresh,
+ FoodOnForkLabel.NO_FOOD.value,
+ np.where(
+ proba > upper_thresh,
+ FoodOnForkLabel.FOOD.value,
+ FoodOnForkLabel.UNSURE.value,
+ ),
+ ),
+ statuses,
+ )
+
+ def overlay_debug_info(self, img: npt.NDArray, t: npt.NDArray) -> npt.NDArray:
+ """
+ Overlays debug information onto a depth image.
+
+ Parameters
+ ----------
+ img: The depth image to overlay debug information onto.
+ t: The closest transforms (homogenous coordinates) to this image's timestamp.
+ Size (num_transforms, 4, 4). Should be outputted by `get_transforms`.
+
+ Returns
+ -------
+ img_with_debug_info: The depth image with debug information overlayed.
+ """
+ # pylint: disable=unused-argument
+ return img
+
+ def visualize_img(self, img: npt.NDArray, t: npt.NDArray) -> None:
+ """
+ Visualizes a depth image. This function is used for debugging, so it helps
+ to not only visualize the img, but also subclass-specific information that
+ can help explain why the img would result in a particular prediction.
+
+ It is acceptable for this function to block until the user closes a window.
+
+ Parameters
+ ----------
+ img: The depth image to visualize.
+ t: The closest transforms (homogenous coordinates) to this image's timestamp.
+ Size (num_transforms, 4, 4). Should be outputted by `get_transforms`.
+ """
+ # pylint: disable=unused-argument
+ show_normalized_depth_img(img, wait=True, window_name="img")
+
+
+class FoodOnForkDummyDetector(FoodOnForkDetector):
+ """
+ A dummy perception algorithm that always predicts the same probability.
+ """
+
+ def __init__(self, proba: float, verbose: bool = False) -> None:
+ """
+ Initializes the dummy perception algorithm.
+
+ Parameters
+ ----------
+ proba: The probability that the dummy algorithm should always predict.
+ verbose: Whether to print debug messages.
+ """
+ super().__init__(verbose)
+ self.proba = proba
+
+ @override
+ def fit(
+ self,
+ X: npt.NDArray,
+ y: npt.NDArray[int],
+ t: npt.NDArray[float],
+ viz_save_dir: Optional[str] = None,
+ ) -> None:
+ pass
+
+ @override
+ def save(self, path: str) -> str:
+ return ""
+
+ @override
+ def load(self, path: str) -> None:
+ pass
+
+ @override
+ def predict_proba(
+ self,
+ X: npt.NDArray,
+ t: npt.NDArray[float],
+ ) -> Tuple[npt.NDArray[float], npt.NDArray[int]]:
+ return (
+ np.full(X.shape[0], self.proba),
+ np.full(X.shape[0], FoodOnForkDetection.SUCCESS),
+ )
+
+
+class FoodOnForkDistanceToNoFOFDetector(FoodOnForkDetector):
+ """
+ A perception algorithm that stores a representative subset of "no FoF" points.
+ It then calculates the average distance between each test point and the nearest
+ no FoF point, and uses a classifier to predict the probability of a
+ test point being FoF based on that distance.
+ """
+
+ # pylint: disable=too-many-instance-attributes
+ # These many are fine.
+
+ AGGREGATORS = {
+ "mean": np.mean,
+ "median": np.median,
+ "max": np.max,
+ "min": np.min,
+ "25p": lambda x: np.percentile(x, 25),
+ "75p": lambda x: np.percentile(x, 75),
+ "90p": lambda x: np.percentile(x, 90),
+ "95p": lambda x: np.percentile(x, 95),
+ }
+
+ def __init__(
+ self,
+ camera_matrix: npt.NDArray,
+ prop_no_fof_points_to_store: float = 0.5,
+ min_points: int = 40,
+ min_distance: float = 0.001,
+ aggregator_name: Optional[str] = "90p",
+ verbose: bool = False,
+ ) -> None:
+ """
+ Initializes the algorithm.
+
+ Parameters
+ ----------
+ camera_matrix: The camera intrinsic matrix (K).
+ prop_no_fof_points_to_store: The proportion of no FoF pointclouds in
+ the train set to set aside for storing no FoF points. Note that not
+ all points in these pointclouds are stored; only those that are >=
+ min_distance m away from the currently stored points.
+ min_points: The minimum number of points in a pointcloud to consider it
+ for comparison. If a pointcloud has fewer points than this, it will
+ return a probability of nan (prediction of UNSURE).
+ min_distance: The minimum distance (m) between stored no FoF points.
+ aggregator_name: The name of the aggregator to use to aggregate the
+ distances between the test point and the stored no FoF points. If None,
+ all aggregators are used. This is typically only useful to compare
+ the performance of different aggregators.
+ verbose: Whether to print debug messages.
+ """
+ # pylint: disable=too-many-arguments
+ # These many are fine.
+
+ super().__init__(verbose)
+ self.camera_matrix = camera_matrix
+ self.prop_no_fof_points_to_store = prop_no_fof_points_to_store
+ self.min_points = min_points
+ self.min_distance = min_distance
+ self.aggregator_name = aggregator_name
+
+ # The attributes that are stored/loaded by the model
+ self.no_fof_points = None
+ self.clf = None
+ self.best_aggregator_name = None
+
+ @property
+ @override
+ def transform_frames(self) -> List[Tuple[str, str]]:
+ return [("forkTip", "camera_color_optical_frame")]
+
+ @staticmethod
+ def distances_between_pointclouds(
+ pointcloud1: npt.NDArray,
+ pointcloud2: npt.NDArray,
+ ) -> npt.NDArray:
+ """
+ For every point in pointcloud1, gets the minimum distance to points in
+ pointcloud2. Note that this is not
+ symmetric; the order of the pointclouds matters.
+
+ Parameters
+ ----------
+ pointcloud1: The test pointcloud. Size (n, k).
+ pointcloud2: The training pointcloud. Size (m, k).
+
+ Returns
+ -------
+ distances: The minimum distance from each point in pointcloud1 to points
+ in pointcloud2. Size (n,).
+ """
+ return np.min(pairwise_distances(pointcloud1, pointcloud2), axis=1)
+
+ def get_representative_points(self, pointcloud: npt.NDArray) -> npt.NDArray:
+ """
+ Returns a subset of points from the pointcloud such that every point in
+ pointcloud is no more than min_distance away from one of the representative
+ points.
+
+ Parameters
+ ----------
+ pointcloud: The pointcloud to get representative points from. Size (n, k).
+
+ Returns
+ -------
+ representative_points: The subset of points from the pointcloud. Size (m, k).
+ """
+ if self.verbose:
+ print("Getting a subset of representative points")
+
+ # Include the first point
+ representative_points = pointcloud[0:1, :]
+
+ # Add points that are >= min_distance m away from the stored points
+ for i in range(1, len(pointcloud)):
+ if self.verbose:
+ print(f"Point {i}/{len(pointcloud)}")
+ contender_point = pointcloud[i]
+ # Get the distance between the contender point and the representative points
+ distance = np.min(
+ np.linalg.norm(representative_points - contender_point, axis=1)
+ )
+ if distance >= self.min_distance:
+ representative_points = np.vstack(
+ [representative_points, contender_point]
+ )
+ return representative_points
+
+ @override
+ def fit(
+ self,
+ X: npt.NDArray,
+ y: npt.NDArray[int],
+ t: npt.NDArray[float],
+ viz_save_dir: Optional[str] = None,
+ ) -> None:
+ # pylint: disable=too-many-locals, too-many-branches, too-many-statements
+ # This is the main logic of the algorithm, so it's okay to have a lot of
+ # local variables.
+
+ # Only keep the datapoints where the transform is not 0
+ i_to_keep = np.where(np.logical_not(np.all(t == 0, axis=(2, 3))))[0]
+ X = X[i_to_keep]
+ y = y[i_to_keep]
+ t = t[i_to_keep]
+
+ # Get the most up-to-date camera matrix
+ if self.camera_info is not None:
+ self.camera_matrix = np.array(self.camera_info.k)
+
+ # Convert all images to pointclouds, excluding those with too few points
+ pointclouds = []
+ y_pointclouds = []
+ for i, img in enumerate(X):
+ pointcloud = depth_img_to_pointcloud(
+ img,
+ *self.crop_top_left,
+ f_x=self.camera_matrix[0],
+ f_y=self.camera_matrix[4],
+ c_x=self.camera_matrix[2],
+ c_y=self.camera_matrix[5],
+ transform=t[i, 0, :, :],
+ )
+ if len(pointcloud) >= self.min_points:
+ pointclouds.append(pointcloud)
+ y_pointclouds.append(y[i])
+
+ # Convert to np arrays
+ # Pointclouds must be dtype object to store arrays of different lengths
+ pointclouds = np.array(pointclouds, dtype=object)
+ y_pointclouds = np.array(y_pointclouds)
+ if self.verbose:
+ print(
+ f"Converted {X.shape[0]} depth images to {pointclouds.shape[0]} pointclouds"
+ )
+
+ # Split the no FoF and FoF pointclouds. The "train set" consists of only
+ # no FoF pointclouds, and is used to find a representative subset of points
+ # to store. The "val set" consists of both no FoF and FoF pointclouds, and
+ # is used to train the classifier.
+ no_fof_pointclouds = pointclouds[y_pointclouds == FoodOnForkLabel.NO_FOOD.value]
+ fof_pointclouds = pointclouds[y_pointclouds == FoodOnForkLabel.FOOD.value]
+ no_fof_pointclouds_train, no_fof_pointclouds_val = train_test_split(
+ no_fof_pointclouds,
+ train_size=self.prop_no_fof_points_to_store,
+ random_state=self.seed,
+ )
+ val_pointclouds = np.concatenate([no_fof_pointclouds_val, fof_pointclouds])
+ val_labels = np.concatenate(
+ [
+ np.zeros((no_fof_pointclouds_val.shape[0],)),
+ np.ones((fof_pointclouds.shape[0],)),
+ ]
+ )
+ if self.verbose:
+ print("Split the no FoF pointclouds into train and val")
+
+ # Store a representative subset of the points
+ all_no_fof_pointclouds_train = np.concatenate(no_fof_pointclouds_train)
+ self.no_fof_points = self.get_representative_points(
+ all_no_fof_pointclouds_train
+ )
+ if self.verbose:
+ print(
+ f"Stored a representative subset of {self.no_fof_points.shape[0]}/"
+ f"{all_no_fof_pointclouds_train.shape[0]} no FoF pointclouds"
+ )
+
+ # Get the aggregators
+ if self.aggregator_name is None:
+ aggregator_names = list(self.AGGREGATORS.keys())
+ else:
+ aggregator_names = [self.aggregator_name]
+
+ # Get the distances from each point in the "val set" to the stored points
+ val_distances = {name: [] for name in aggregator_names}
+ for i, pointcloud in enumerate(val_pointclouds):
+ if self.verbose:
+ print(
+ "Computing distance to stored points for val point "
+ f"{i}/{val_pointclouds.shape[0]}"
+ )
+ point_distances = (
+ FoodOnForkDistanceToNoFOFDetector.distances_between_pointclouds(
+ pointcloud, self.no_fof_points
+ )
+ )
+ for name in aggregator_names:
+ aggregator = self.AGGREGATORS[name]
+ distance = aggregator(point_distances)
+ val_distances[name].append(distance)
+ val_distances = {
+ name: np.array(val_distances[name]) for name in aggregator_names
+ }
+
+ # Split the validation set into train and val. This is to pick the best
+ # aggregator to use.
+ val_train_i, val_val_i = train_test_split(
+ np.arange(val_labels.shape[0]),
+ train_size=0.8,
+ random_state=self.seed,
+ stratify=val_labels,
+ )
+
+ # Train the classifier(s)
+ f1_scores = {}
+ clfs = {}
+ for name in aggregator_names:
+ # Train the classifier
+ if self.verbose:
+ print(f"Training the classifier for aggregator {name}")
+ clf = LogisticRegression(random_state=self.seed, penalty=None)
+ clf.fit(
+ val_distances[name].reshape(-1, 1)[val_train_i, :],
+ val_labels[val_train_i],
+ )
+ clfs[name] = clf
+ if self.verbose:
+ print(
+ f"Trained the classifier for aggregator {name}, got coeff "
+ f"{clf.coef_} and intercept {clf.intercept_}"
+ )
+
+ # Get the f1 score
+ y_pred = clf.predict(val_distances[name].reshape(-1, 1)[val_val_i])
+ y_true = val_labels[val_val_i]
+ f1_scores[name] = f1_score(y_true, y_pred)
+ if self.verbose:
+ print(f"F1 score for aggregator {name}: {f1_scores[name]}")
+
+ # Save a visualization of the classifier
+ if viz_save_dir is not None:
+ max_aggregated_distance = max(
+ np.max(val_distances[name][val_val_i]) for name in aggregator_names
+ )
+ for name in aggregator_names:
+ # Create a scatterplot where the x-axis is the distance in val_val
+ # and the y-axis is the label. Add some y-jitter to make it easier
+ # to see the points.
+ y_true = val_labels[val_val_i] + np.random.normal(
+ 0, 0.1, val_val_i.shape[0]
+ )
+ fig, ax = plt.subplots(figsize=(5, 4))
+ ax.scatter(
+ val_distances[name][val_val_i], y_true, label="True", alpha=0.5
+ )
+ ax.set_xlim(0, max_aggregated_distance)
+
+ # Add a line for the probability predictions of num_points over the range
+ # of distances
+ num_points = 100
+ distances = np.linspace(0.0, max_aggregated_distance, num_points)
+ probas = clfs[name].predict_proba(distances.reshape(-1, 1))[:, 1]
+ ax.plot(distances, probas, label="Classifier Probabilities")
+
+ # Add a title
+ ax.set_title(
+ f"Classifier for Aggregator {name}. F1 Score: {f1_scores[name]}"
+ )
+
+ # Save the figure
+ fig.savefig(
+ os.path.join(
+ viz_save_dir,
+ f"classifier_{clf.__class__.__name__}_aggregator_{name}.png",
+ )
+ )
+
+ # Pick the best aggregator
+ self.best_aggregator_name, best_f1_score = max(
+ f1_scores.items(), key=lambda x: x[1]
+ )
+ if self.verbose:
+ print(
+ f"Best aggregator: {self.best_aggregator_name} with f1 score {best_f1_score}"
+ )
+ self.clf = clfs[self.best_aggregator_name]
+
+ @override
+ def save(self, path: str) -> str:
+ if (
+ self.no_fof_points is None
+ or self.clf is None
+ or self.best_aggregator_name is None
+ ):
+ raise ValueError(
+ "The model has not been trained yet. Call fit before saving."
+ )
+ # If the path has an extension, remove it.
+ path = os.path.splitext(path)[0]
+ np.savez_compressed(
+ path,
+ no_fof_points=self.no_fof_points,
+ clf=np.array([self.clf], dtype=object),
+ best_aggregator_name=self.best_aggregator_name,
+ )
+ return path + ".npz"
+
+ @override
+ def load(self, path: str) -> None:
+ ext = os.path.splitext(path)[1]
+ if len(ext) == 0:
+ path = path + ".npz"
+ params = np.load(path, allow_pickle=True)
+ self.no_fof_points = params["no_fof_points"]
+ self.clf = params["clf"][0]
+ self.best_aggregator_name = str(params["best_aggregator_name"])
+ if self.verbose:
+ print(
+ f"Loaded model with intercept {self.clf.intercept_} and coef {self.clf.coef_} "
+ f"and best aggregator {self.best_aggregator_name} and num stored points "
+ f"{self.no_fof_points.shape[0]}"
+ )
+
+ @override
+ def predict_proba(
+ self,
+ X: npt.NDArray,
+ t: npt.NDArray[float],
+ ) -> Tuple[npt.NDArray[float], npt.NDArray[int]]:
+ probas = []
+ statuses = []
+
+ # Get the prediction per image.
+ if self.verbose:
+ inference_times = []
+ for i, img in enumerate(X):
+ if self.verbose:
+ start_time = time.time()
+
+ # If all elements of the transform are 0, set the proba to nan
+ if np.all(np.isclose(t[i, 0, :, :], 0.0)):
+ probas.append(np.nan)
+ statuses.append(FoodOnForkDetection.ERROR_NO_TRANSFORM)
+ continue
+
+ # Convert the image to a pointcloud
+ pointcloud = depth_img_to_pointcloud(
+ img,
+ *self.crop_top_left,
+ f_x=self.camera_matrix[0],
+ f_y=self.camera_matrix[4],
+ c_x=self.camera_matrix[2],
+ c_y=self.camera_matrix[5],
+ transform=t[i, 0, :, :],
+ )
+
+ # If there are too few points, set the proba to nan
+ if len(pointcloud) < self.min_points:
+ probas.append(np.nan)
+ statuses.append(FoodOnForkDetection.ERROR_TOO_FEW_POINTS)
+ continue
+
+ # Otherwise, aggregate the distances to the stored no FoF points and use
+ # the classifier to predict the probability of food on the fork
+ distances = FoodOnForkDistanceToNoFOFDetector.distances_between_pointclouds(
+ pointcloud, self.no_fof_points
+ )
+ distance = self.AGGREGATORS[self.best_aggregator_name](distances)
+ proba = self.clf.predict_proba(np.array([[distance]]))[0, 1]
+ probas.append(proba)
+ statuses.append(FoodOnForkDetection.SUCCESS)
+ if self.verbose:
+ inference_times.append(time.time() - start_time)
+ if self.verbose:
+ print(
+ f"Inference Time: min {np.min(inference_times)}, max {np.max(inference_times)}, "
+ f"mean {np.mean(inference_times)}, 25th percentile {np.percentile(inference_times, 25)}, "
+ f"50th percentile {np.percentile(inference_times, 50)}, "
+ f"75th percentile {np.percentile(inference_times, 75)}."
+ )
+
+ return np.array(probas), np.array(statuses, dtype=int)
+
+ @override
+ def predict(
+ self,
+ X: npt.NDArray,
+ t: npt.NDArray[float],
+ lower_thresh: float,
+ upper_thresh: float,
+ proba: Optional[npt.NDArray] = None,
+ statuses: Optional[npt.NDArray[int]] = None,
+ ) -> Tuple[npt.NDArray[int], npt.NDArray[int]]:
+ # pylint: disable=too-many-arguments
+ # These many are fine.
+ if proba is None or statuses is None:
+ proba, statuses = self.predict_proba(X, t)
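+        # Map probabilities to labels: probabilities below lower_thresh (or frames
+        # with too few points) are labeled NO_FOOD, probabilities above upper_thresh
+        # are labeled FOOD, and everything in between is labeled UNSURE.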
+ return (
+ np.where(
+ (proba < lower_thresh)
+ | (statuses == FoodOnForkDetection.ERROR_TOO_FEW_POINTS),
+ FoodOnForkLabel.NO_FOOD.value,
+ np.where(
+ proba > upper_thresh,
+ FoodOnForkLabel.FOOD.value,
+ FoodOnForkLabel.UNSURE.value,
+ ),
+ ),
+ statuses,
+ )
+
+ @override
+ def overlay_debug_info(self, img: npt.NDArray, t: npt.NDArray) -> npt.NDArray:
+ # pylint: disable=too-many-locals
+ # This is done to make it clear what the camera matrix values are.
+
+ # First, convert all no_fof_points back to the camera frame by applying
+        # the inverse of the homogeneous transform t[0, :, :]
+ no_fof_points_homogenous = np.hstack(
+ [self.no_fof_points, np.ones((self.no_fof_points.shape[0], 1))]
+ )
+ no_fof_points_camera = np.dot(
+ np.linalg.inv(t[0, :, :]), no_fof_points_homogenous.T
+ ).T[:, :3]
+
+ # For every point in the no_fof_points, convert them back into (u,v) pixel
+ # coordinates.
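+        # Pinhole projection: u = f_x * x / z + c_x and v = f_y * y / z + c_y. The
+        # mm conversion cancels in the x/z and y/z ratios, so it only affects the
+        # integer rounding of the stored points.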
+ no_fof_points_mm = (no_fof_points_camera * 1000).astype(int)
+ f_x = self.camera_matrix[0]
+ f_y = self.camera_matrix[4]
+ c_x = self.camera_matrix[2]
+ c_y = self.camera_matrix[5]
+ us = (f_x * no_fof_points_mm[:, 0] / no_fof_points_mm[:, 2] + c_x).astype(int)
+ vs = (f_y * no_fof_points_mm[:, 1] / no_fof_points_mm[:, 2] + c_y).astype(int)
+
+ # For every point, draw a circle around that point in the image
+ color = (0, 0, 0)
+ alpha = 0.75
+ radius = 5
+ img_with_debug_info = img.copy()
+ for u, v in zip(us, vs):
+ cv2.circle(img_with_debug_info, (u, v), radius, color, -1)
+ return cv2.addWeighted(img_with_debug_info, alpha, img.copy(), 1 - alpha, 0)
+
+ @override
+ def visualize_img(self, img: npt.NDArray, t: npt.NDArray) -> None:
+ # Convert the image to a pointcloud
+ pointclouds = [
+ depth_img_to_pointcloud(
+ img,
+ *self.crop_top_left,
+ f_x=self.camera_matrix[0],
+ f_y=self.camera_matrix[4],
+ c_x=self.camera_matrix[2],
+ c_y=self.camera_matrix[5],
+ transform=t[0, :, :],
+ )
+ ]
+ colors = [[0, 0, 1]]
+ sizes = [5]
+ markerstyles = ["o"]
+ labels = ["Test"]
+
+ if self.no_fof_points is not None:
+ print(f"Visualizing the {self.no_fof_points.shape[0]} stored no FoF points")
+ pointclouds.append(self.no_fof_points)
+ colors.append([1, 0, 0])
+ sizes.append(5)
+ markerstyles.append("^")
+ labels.append("Train")
+
+ show_3d_scatterplot(
+ pointclouds,
+ colors,
+ sizes,
+ markerstyles,
+ labels,
+ title="Img vs. Stored No FoF Points",
+ )
diff --git a/ada_feeding_perception/ada_feeding_perception/food_on_fork_train_test.py b/ada_feeding_perception/ada_feeding_perception/food_on_fork_train_test.py
new file mode 100644
index 00000000..84c9caa2
--- /dev/null
+++ b/ada_feeding_perception/ada_feeding_perception/food_on_fork_train_test.py
@@ -0,0 +1,902 @@
+"""
+This script takes in a variety of command line arguments and then trains and tests a
+FoodOnForkDetector as configured by those arguments. Note that although this is not
+a ROS node, it relies on helper functions and types in the ada_feeding, ada_feeding_msgs,
+and ada_feeding_perception packages. The easiest way to access those is to build
+your workspace and source it before running this script.
+"""
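+
+# Example invocation (the model ID and seed are arbitrary, the camera intrinsics
+# are abbreviated from config/food_on_fork_detection.yaml, and the paths assume
+# this script is run from its source directory after building and sourcing the
+# workspace):
+#
+#     python3 food_on_fork_train_test.py \
+#         --model-classes '{"distance_detector": "ada_feeding_perception.food_on_fork_detectors.FoodOnForkDistanceToNoFOFDetector"}' \
+#         --model-kwargs '{"distance_detector": {"camera_matrix": [614.59, 0.0, 312.14, 0.0, 614.69, 223.71, 0.0, 0.0, 1.0]}}' \
+#         --data-dir ../data/food_on_fork \
+#         --seed 42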
+
+# Standard Imports
+import argparse
+import json
+import os
+import textwrap
+import time
+from typing import Any, Dict, List, Optional, Tuple
+
+# Third-party imports
+from builtin_interfaces.msg import Time
+import cv2
+from cv_bridge import CvBridge
+from geometry_msgs.msg import (
+ Quaternion,
+ Transform,
+ TransformStamped,
+ Vector3,
+)
+import numpy as np
+import numpy.typing as npt
+import pandas as pd
+from rosbags.rosbag2 import Reader
+from rosbags.serde import deserialize_cdr
+from sensor_msgs.msg import CameraInfo
+from sklearn.metrics import (
+ accuracy_score,
+ confusion_matrix,
+)
+from sklearn.model_selection import train_test_split
+from std_msgs.msg import Header
+from tf2_ros.buffer import Buffer
+
+# Local imports
+from ada_feeding.helpers import import_from_string
+from ada_feeding_perception.food_on_fork_detectors import (
+ FoodOnForkDetector,
+ FoodOnForkLabel,
+)
+from ada_feeding_perception.helpers import ros_msg_to_cv2_image
+from ada_feeding_perception.depth_post_processors import (
+ create_spatial_post_processor,
+ create_temporal_post_processor,
+)
+
+
+def read_args() -> argparse.Namespace:
+ """
+ Read the command line arguments.
+
+ Returns
+ -------
+ args: argparse.Namespace
+ The command line arguments.
+ """
+ parser = argparse.ArgumentParser(
+ description=(
+ "Train and test one or more FoodOnForkDetectors on an offline dataset."
+ )
+ )
+
+ # Configure the models
+ parser.add_argument(
+ "--model-classes",
+ help=(
+ "A JSON-encoded string where keys are an arbitrary model ID and "
+ "values are the class names to use for that model. e.g., "
+ '{"dummy_detector": "ada_feeding_perception.food_on_fork_detectors.FoodOnForkDummyDetector"}'
+ ),
+ required=True,
+ )
+ parser.add_argument(
+ "--model-kwargs",
+ default="{}",
+ help=(
+ "A JSON-encoded string where keys are the model ID and values are "
+ "a dictionary of keyword arguments to pass to the model's constructor. e.g., "
+ '{"dummy_detector": {"proba": 0.1}}'
+ ),
+ )
+
+ # Configure post-processing of the depth images
+ parser.add_argument(
+ "--temporal-window-size",
+ default=None,
+ type=int,
+ help=(
+ "The size of the temporal window to use for post-processing. If unset, "
+ "no temporal post-processing will be done. See depth_post_processors.py "
+ "for more details."
+ ),
+ )
+ parser.add_argument(
+ "--spatial-num-pixels",
+ default=None,
+ type=int,
+ help=(
+ "The number of pixels to use for the spatial post-processing. If unset, "
+ "no spatial post-processing will be done. See depth_post_processors.py "
+ "for more details."
+ ),
+ )
+
+ # Configure the cropping/masking of depth images. These should exactly match
+ # the cropping/masking done in the real-time detector (in config/food_on_fork_detection.yaml).
+ parser.add_argument(
+ "--crop-top-left",
+ default=(0, 0),
+ type=int,
+ nargs="+",
+ help=("The top-left corner of the crop rectangle. The format is (u, v)."),
+ )
+ parser.add_argument(
+ "--crop-bottom-right",
+ default=(640, 480),
+ type=int,
+ nargs="+",
+ help=("The bottom-right corner of the crop rectangle. The format is (u, v)."),
+ )
+ parser.add_argument(
+ "--depth-min-mm",
+ default=0,
+ type=int,
+ help=("The minimum depth value to consider in the depth images."),
+ )
+ parser.add_argument(
+ "--depth-max-mm",
+ default=20000,
+ type=int,
+ help=("The maximum depth value to consider in the depth images."),
+ )
+
+ # Configure the dataset
+ parser.add_argument(
+ "--data-dir",
+ default="../data/food_on_fork",
+ help=(
+ "The directory containing the training and testing data. This path should "
+ "have a file called `bags_metadata.csv` that contains the labels for bagfiles, "
+ "and one folder per bagfile referred to in the CSV. This path should be "
+ "relative to **this file's** location."
+ ),
+ )
+ parser.add_argument(
+ "--depth-topic",
+ default="/local/camera/aligned_depth_to_color/image_raw",
+ help=("The topic to use for depth images."),
+ )
+ parser.add_argument(
+ "--color-topic",
+ default="/local/camera/color/image_raw/compressed",
+ help=("The topic to use for color images. Used for debugging."),
+ )
+ parser.add_argument(
+ "--camera-info-topic",
+ default="/local/camera/color/camera_info",
+ help=("The topic to use for camera info."),
+ )
+ parser.add_argument(
+ "--exclude-motion",
+ default=False,
+ action="store_true",
+ help=("If set, exclude images when the robot arm is moving in the dataset."),
+ )
+ parser.add_argument(
+ "--rosbags-select",
+ default=[],
+ type=str,
+ nargs="+",
+ help="If set, only rosbags listed here will be included",
+ )
+ parser.add_argument(
+ "--rosbags-skip",
+ default=[],
+ type=str,
+ nargs="+",
+ help="If set, rosbags listed here will be excluded",
+ )
+
+ # Configure the training and testing operations
+ parser.add_argument(
+ "--no-train",
+ default=False,
+ action="store_true",
+ help="If set, do not train the models and instead only test them.",
+ )
+ parser.add_argument(
+ "--seed",
+ default=None,
+ type=int,
+ help=(
+ "The random seed to use for the train-test split and in the detector. "
+ "If unspecified, the seed will be the current time."
+ ),
+ )
+ parser.add_argument(
+ "--train-set-size",
+ default=0.8,
+ type=float,
+ help="The fraction of the dataset to use for training",
+ )
+ parser.add_argument(
+ "--model-dir",
+ default="../model",
+ help=(
+ "The directory to save and load the trained model to/from. The path should be "
+ "relative to **this file's** location. "
+ ),
+ )
+ parser.add_argument(
+ "--output-dir",
+ default="../model/results",
+ help=(
+ "The directory to save the train and test results to. The path should be "
+ "relative to **this file's** location. "
+ ),
+ )
+ parser.add_argument(
+ "--lower-thresh",
+ default=0.5,
+ type=float,
+ help=(
+ "If the predicted probability of food on fork is <= this value, the "
+ "detector will predict that there is no food on the fork."
+ ),
+ )
+ parser.add_argument(
+ "--upper-thresh",
+ default=0.5,
+ type=float,
+ help=(
+ "If the predicted probability of food on fork is > this value, the "
+ "detector will predict that there is food on the fork."
+ ),
+ )
+ parser.add_argument(
+ "--max-eval-n",
+ default=None,
+ type=int,
+ help=(
+ "The maximum number of evaluations to perform. If None, all evaluations "
+ "will be performed. Typically used when debugging a detector."
+ ),
+ )
+ parser.add_argument(
+ "--viz-rosbags",
+ default=False,
+ action="store_true",
+ help=(
+ "If set, render the color images in the rosbag and the label while "
+            "loading the data. This is useful for fine-tuning ground-truth labels."
+ ),
+ )
+ parser.add_argument(
+ "--viz-evaluation",
+ default=False,
+ action="store_true",
+ help=(
+ "If set, visualize all images where the model was wrong. This is useful "
+ "for debugging, but note that it will block after every wrong prediction "
+ "until the visualization window is closed."
+ ),
+ )
+ parser.add_argument(
+ "--viz-fit-save-dir",
+ default=None,
+ help=(
+ "If set, visualize the fit of the model and save the images to this directory."
+ ),
+ )
+
+ return parser.parse_args()
+
+
+def load_data(
+ models: Dict[str, FoodOnForkDetector],
+ data_dir: str,
+ depth_topic: str,
+ color_topic: str,
+ camera_info_topic: str,
+ crop_top_left: Tuple[int, int],
+ crop_bottom_right: Tuple[int, int],
+ depth_min_mm: int,
+ depth_max_mm: int,
+ exclude_motion: bool,
+ rosbags_select: Optional[List[str]] = None,
+ rosbags_skip: Optional[List[str]] = None,
+ temporal_window_size: Optional[int] = None,
+ spatial_num_pixels: Optional[int] = None,
+ viz: bool = False,
+) -> Tuple[npt.NDArray, npt.NDArray[int], Dict[str, npt.NDArray], CameraInfo]:
+ """
+ Load the data specified in the command line arguments.
+
+ Parameters
+ ----------
+ models: dict
+ A dictionary where keys are the model ID and values are the model instances.
+ This is necessary to get the transforms that each model is looking for.
+ data_dir: str
+ The directory containing the training and testing data. The path should be
+        relative to **this file's** location. This directory should contain a
+        `bags_metadata.csv` file with timestamped labels for each rosbag, and one
+        subdirectory per rosbag referenced in that CSV, recorded with the topics
+        specified in the command line arguments.
+ depth_topic: str
+ The topic to use for depth images.
+ color_topic: str
+ The topic to use for color images. Used for debugging.
+ camera_info_topic: str
+ The topic to use for camera info.
+ crop_top_left, crop_bottom_right: Tuple[int, int]
+ The top-left and bottom-right corners of the crop rectangle.
+ depth_min_mm, depth_max_mm: int
+ The minimum and maximum depth values to consider in the depth images.
+ exclude_motion: bool
+ If True, exclude images when the robot arm is moving in the dataset.
+ rosbags_select: List[str], optional
+ If set, only rosbags in this list will be included
+ rosbags_skip: List[str], optional
+ If set, rosbags in this list will be excluded
+ temporal_window_size: int, optional
+ The size of the temporal window to use for post-processing. If unset,
+ no temporal post-processing will be done.
+ spatial_num_pixels: int, optional
+ The number of pixels to use for the spatial post-processing. If unset,
+ no spatial post-processing will be done.
+ viz: bool, optional
+ If True, visualize the depth images as they are loaded.
+
+ Returns
+ -------
+ X: npt.NDArray
+ The depth images to predict on.
+ y: npt.NDArray[int]
+ The labels for whether there is food on the fork.
+ t: dict
+ For each model, the requested transforms.
+ camera_info: CameraInfo
+ The camera info for the depth images. We assume it is static across all
+ depth images.
+ """
+ # pylint: disable=too-many-locals, too-many-arguments, too-many-branches, too-many-statements
+ # Okay since we want to make it a flexible method.
+ print("Loading data...")
+
+ # Replace the optional arguments
+ if rosbags_select is None:
+ rosbags_select = []
+ if rosbags_skip is None:
+ rosbags_skip = []
+
+ # Set up the post-processors
+ bridge = CvBridge()
+ post_processors = []
+ if temporal_window_size is not None:
+ post_processors.append(
+ create_temporal_post_processor(temporal_window_size, bridge)
+ )
+ if spatial_num_pixels is not None:
+ post_processors.append(
+ create_spatial_post_processor(spatial_num_pixels, bridge)
+ )
+
+ absolute_data_dir = os.path.join(os.path.dirname(__file__), data_dir)
+
+ # Initialize the data
+ w = crop_bottom_right[0] - crop_top_left[0]
+ h = crop_bottom_right[1] - crop_top_left[1]
+ X = np.zeros((0, h, w), dtype=np.uint16)
+ y = np.zeros(0, dtype=int)
+ tf_buffer = Buffer()
+ model_to_frames = {
+ model_id: model.transform_frames for model_id, model in models.items()
+ }
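+    # For each model, t will hold one 4x4 homogeneous transform per requested
+    # frame per image, i.e., an array of shape (num_images, num_frames, 4, 4).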
+ t = {
+ model_id: np.zeros((0, len(frames), 4, 4), dtype=float)
+ for model_id, frames in model_to_frames.items()
+ }
+ all_frames = []
+ for model in models.values():
+ all_frames.extend(model.transform_frames)
+ all_frames = list(set(all_frames))
+
+ # Load the metadata
+ metadata = pd.read_csv(os.path.join(absolute_data_dir, "bags_metadata.csv"))
+ bagname_to_annotations = {}
+ for _, row in metadata.iterrows():
+ rosbag_name = row["rosbag_name"]
+ time_from_start = row["time_from_start"]
+ food_on_fork = row["food_on_fork"]
+ arm_moving = row["arm_moving"]
+ if rosbag_name not in bagname_to_annotations:
+ bagname_to_annotations[rosbag_name] = []
+ bagname_to_annotations[rosbag_name].append(
+ (time_from_start, food_on_fork, arm_moving)
+ )
+
+ # Sort the rosbag names in chronological order based on the first message.
+ # This is necessary so the latest transform is accurate.
+ rosbag_name_to_first_timestamp = {}
+ for rosbag_name in bagname_to_annotations:
+ with Reader(os.path.join(absolute_data_dir, rosbag_name)) as reader:
+ for connection, timestamp, _ in reader.messages():
+ rosbag_name_to_first_timestamp[rosbag_name] = timestamp
+ break
+ sorted_rosbag_names = [
+ k
+ for k, _ in sorted(
+ rosbag_name_to_first_timestamp.items(), key=lambda item: item[1]
+ )
+ ]
+
+ # Load the data
+ camera_info = None
+ num_images_no_points = 0
+ for rosbag_name in sorted_rosbag_names:
+ annotations = bagname_to_annotations[rosbag_name]
+ if (len(rosbags_select) > 0 and rosbag_name not in rosbags_select) or (
+ len(rosbags_skip) > 0 and rosbag_name in rosbags_skip
+ ):
+ print(f"Skipping rosbag {rosbag_name}")
+ continue
+ annotations.sort()
+ i = 0
+ with Reader(os.path.join(absolute_data_dir, rosbag_name)) as reader:
+ # Get the depth message count
+ for connection in reader.connections:
+ if connection.topic == depth_topic:
+ depth_msg_count = connection.msgcount
+ break
+ # Extend X and y by depth_msg_count
+ j = y.shape[0]
+ X = np.concatenate((X, np.zeros((depth_msg_count, h, w), dtype=np.uint16)))
+ y = np.concatenate((y, np.zeros(depth_msg_count, dtype=int)))
+ for model_id in t:
+ t[model_id] = np.concatenate(
+ (
+ t[model_id],
+                        np.zeros(
+                            (depth_msg_count, len(model_to_frames[model_id]), 4, 4),
+                            dtype=float,
+                        ),
+ )
+ )
+
+ start_time = None
+ for connection, timestamp, rawdata in reader.messages():
+ if start_time is None:
+ start_time = timestamp
+ # Depth Image
+ if connection.topic == depth_topic:
+ msg = deserialize_cdr(rawdata, connection.msgtype)
+ elapsed_time = (timestamp - start_time) / 10.0**9
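+                    # Advance to the annotation that applies at this elapsed time
+                    # (the latest annotation whose start time precedes it)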
+ while (
+ i < len(annotations) - 1
+ and elapsed_time > annotations[i + 1][0]
+ ):
+ i += 1
+ arm_moving = annotations[i][2]
+ if exclude_motion and arm_moving:
+ # Skip images when the robot arm is moving
+ continue
+ if annotations[i][1] == FoodOnForkLabel.FOOD.value:
+ label = 1
+ elif annotations[i][1] == FoodOnForkLabel.NO_FOOD.value:
+ label = 0
+ else:
+ # Skip images with unknown label
+ continue
+ # Post-process the image
+ for post_processor in post_processors:
+ msg = post_processor(msg)
+ img = ros_msg_to_cv2_image(msg, bridge)
+ img = img[
+ crop_top_left[1] : crop_bottom_right[1],
+ crop_top_left[0] : crop_bottom_right[0],
+ ]
+ img = np.where(
+ (img >= depth_min_mm) & (img <= depth_max_mm), img, 0
+ )
+ if np.all(img == 0):
+ num_images_no_points += 1
+ X[j, :, :] = img
+ y[j] = label
+ # Get the latest transform
+ transforms = FoodOnForkDetector.get_transforms(
+ all_frames,
+ tf_buffer,
+ )
+                    for model_id, frames in model_to_frames.items():
+                        frames_i = [all_frames.index(frame) for frame in frames]
+                        t[model_id][j, :, :, :] = np.array(
+ [transforms[frame_i] for frame_i in frames_i],
+ dtype=float,
+ ).reshape((len(frames), 4, 4))
+ j += 1
+ # Camera Info
+ elif connection.topic == camera_info_topic and camera_info is None:
+ camera_info = deserialize_cdr(rawdata, connection.msgtype)
+ # RGB Image
+ elif viz and connection.topic == color_topic:
+ msg = deserialize_cdr(rawdata, connection.msgtype)
+ print(f"Elapsed Time: {(timestamp - start_time) / 10.0**9}")
+ img = ros_msg_to_cv2_image(msg, bridge)
+ # A box around the forktip
+ x0, y0 = crop_top_left
+ x1, y1 = crop_bottom_right
+ fof_color = (0, 255, 0)
+ no_fof_color = (255, 0, 0)
+ color = fof_color if j == 0 or y[j - 1] == 1 else no_fof_color
+ img = cv2.rectangle(img, (x0, y0), (x1, y1), color, 2)
+ img = cv2.circle(
+ img, ((x0 + x1) // 2, (y0 + y1) // 2), 5, color, -1
+ )
+ cv2.imshow("RGB Image", img)
+ cv2.waitKey(1)
+ # TF Topics
+ elif connection.topic in {"/tf", "/tf_static"}:
+ msg = deserialize_cdr(rawdata, connection.msgtype)
+ for transform in msg.transforms:
+ transform = TransformStamped(
+ header=Header(
+ stamp=Time(
+ sec=transform.header.stamp.sec,
+ nanosec=transform.header.stamp.nanosec,
+ ),
+ frame_id=transform.header.frame_id,
+ ),
+ child_frame_id=transform.child_frame_id,
+ transform=Transform(
+ translation=Vector3(
+ x=transform.transform.translation.x,
+ y=transform.transform.translation.y,
+ z=transform.transform.translation.z,
+ ),
+ rotation=Quaternion(
+ x=transform.transform.rotation.x,
+ y=transform.transform.rotation.y,
+ z=transform.transform.rotation.z,
+ w=transform.transform.rotation.w,
+ ),
+ ),
+ )
+ if connection.topic == "/tf":
+ tf_buffer.set_transform(transform, "default_authority")
+ else:
+ tf_buffer.set_transform_static(
+ transform, "default_authority"
+ )
+
+    # Truncate the extra all-zero rows at the end of X, y, and t
+    print(f"Proportion of images with no valid depth pixels: {num_images_no_points/j}")
+ X = X[:j]
+ y = y[:j]
+ for model_id in t:
+ t[model_id] = t[model_id][:j]
+
+ print(f"Done loading data. {X.shape[0]} depth images loaded.")
+ return X, y, t, camera_info
+
+
+def load_models(
+ model_classes: str,
+ model_kwargs: str,
+ seed: int,
+ crop_top_left: Tuple[int, int],
+ crop_bottom_right: Tuple[int, int],
+) -> Dict[str, FoodOnForkDetector]:
+ """
+ Load the models specified in the command line arguments.
+
+ Parameters
+ ----------
+ model_classes: str
+ A JSON-encoded dictionary where keys are an arbitrary model ID and values
+ are the class names to use for that model.
+ model_kwargs: str
+ A JSON-encoded dictionary where keys are the model ID and values are a
+ dictionary of keyword arguments to pass to the model's constructor.
+ seed: int
+ The random seed to use in the detector.
+ crop_top_left, crop_bottom_right: Tuple[int, int]
+ The top-left and bottom-right corners of the crop rectangle.
+
+ Returns
+ -------
+ models: dict
+ A dictionary where keys are the model ID and values are the model instances.
+ """
+ print("Loading models...")
+
+ # Parse the JSON strings
+ model_classes = json.loads(model_classes)
+ model_kwargs = json.loads(model_kwargs)
+
+ models = {}
+ for model_id, model_class in model_classes.items():
+ # Load the class
+ model_class = import_from_string(model_class)
+
+ # Get the kwargs
+ kwargs = model_kwargs.get(model_id, {})
+
+ # Create the model
+ models[model_id] = model_class(**kwargs)
+ models[model_id].seed = seed
+ models[model_id].crop_top_left = crop_top_left
+ models[model_id].crop_bottom_right = crop_bottom_right
+
+ print(f"Done loading models with IDs {list(model_classes.keys())}.")
+ return models
+
+
+def train_models(
+ models: Dict[str, Any],
+ X: npt.NDArray,
+ y: npt.NDArray,
+ t: Dict[str, npt.NDArray],
+ model_dir: str,
+ viz_fit_save_dir: Optional[str] = None,
+) -> None:
+ """
+ Train the models on the training data.
+
+ Parameters
+ ----------
+ models: dict
+ A dictionary where keys are the model ID and values are the model instances.
+ X: npt.NDArray
+ The depth images to train on.
+ y: npt.NDArray
+ The labels for the depth images.
+ t: dict
+ For each model, the requested transforms.
+ model_dir: str
+ The directory to save the trained model to. The path should be
+ relative to **this file's** location.
+ viz_fit_save_dir: str, optional
+ If set, visualize the fit of the model and save the images to this directory.
+ """
+ # pylint: disable=too-many-arguments
+ # This is okay.
+ absolute_model_dir = os.path.join(os.path.dirname(__file__), model_dir)
+
+ for model_id, model in models.items():
+ print(f"Training model {model_id}...")
+ model.fit(X, y, t[model_id], viz_fit_save_dir)
+ save_path = model.save(os.path.join(absolute_model_dir, model_id))
+ print(f"Done. Saved to '{save_path}'.")
+
+
+def evaluate_models(
+ models: Dict[str, Any],
+ train_X: npt.NDArray,
+ test_X: npt.NDArray,
+ train_y: npt.NDArray,
+ test_y: npt.NDArray,
+ train_t: Dict[str, npt.NDArray],
+ test_t: Dict[str, npt.NDArray],
+ model_dir: str,
+ output_dir: str,
+ lower_thresh: float,
+ upper_thresh: float,
+ max_eval_n: Optional[int] = None,
+ viz: bool = False,
+) -> None:
+ """
+    Evaluate the models on both the training and testing data.
+
+ Parameters
+ ----------
+ models: dict
+ A dictionary where keys are the model ID and values are the model instances.
+ train_X, test_X: npt.NDArray
+        The depth images to evaluate on.
+    train_y, test_y: npt.NDArray
+ The labels for the depth images.
+ train_t, test_t: dict
+ For each model, the requested transforms.
+ model_dir: str
+ The directory to load the trained model from. The path should be
+ relative to **this file's** location.
+ output_dir: str
+ The directory to save the train and test results to. The path should be
+ relative to **this file's** location.
+ lower_thresh: float
+ If the predicted probability of food on fork is <= this value, the
+ detector will predict that there is no food on the fork.
+ upper_thresh: float
+ If the predicted probability of food on fork is > this value, the
+ detector will predict that there is food on the fork.
+ max_eval_n: int, optional
+ The maximum number of evaluations to perform. If None, all evaluations
+ will be performed. Typically used when debugging a detector.
+ viz: bool, optional
+ If True, visualize the depth images as they are evaluated.
+ """
+ # pylint: disable=too-many-locals, too-many-arguments, too-many-nested-blocks
+ # This function is meant to be flexible.
+
+ absolute_model_dir = os.path.join(os.path.dirname(__file__), model_dir)
+ absolute_output_dir = os.path.join(os.path.dirname(__file__), output_dir)
+
+ # Create the output dir if it does not exist
+ if not os.path.exists(absolute_output_dir):
+ os.makedirs(absolute_output_dir)
+ print(f"Created output directory {absolute_output_dir}.")
+
+ results_df = []
+ results_df_columns = [
+ "model_id",
+ "y_true",
+ "y_pred_proba",
+ "y_pred_statuses",
+ "y_pred",
+ "seed",
+ "dataset",
+ ]
+ results_txt = ""
+ for model_id, model in models.items():
+        print(f"Evaluating model {model_id}...")
+ # First, load the model
+ load_path = os.path.join(absolute_model_dir, model_id)
+ print(f"Loading model {model_id} from {load_path}...", end="")
+ model.load(load_path)
+ print("Done.")
+ results_txt += f"Model {model_id} from {load_path}:\n"
+
+ for label, (X, y, t) in [
+ ("train", (train_X, train_y, train_t[model_id])),
+ ("test", (test_X, test_y, test_t[model_id])),
+ ]:
+ if max_eval_n is not None:
+ X = X[:max_eval_n]
+ y = y[:max_eval_n]
+ t = t[:max_eval_n]
+ print(f"Evaluating model {model_id} on {label} dataset...")
+ y_pred_proba, y_pred_statuses = model.predict_proba(X, t)
+ y_pred, _ = model.predict(
+ X, t, lower_thresh, upper_thresh, y_pred_proba, y_pred_statuses
+ )
+ for i in range(y_pred_proba.shape[0]):
+ results_df.append(
+ [
+ model_id,
+ y[i],
+ y_pred_proba[i],
+ y_pred_statuses[i],
+ y_pred[i],
+ model.seed,
+ label,
+ ]
+ )
+ print("Done evaluating model.")
+
+ if viz:
+ # Visualize all images where the model was wrong
+ for i in range(y_pred_proba.shape[0]):
+ if y[i] != y_pred[i]:
+ print(f"Mispredicted: y_true: {y[i]}, y_pred: {y_pred[i]}")
+ model.visualize_img(X[i], t[i])
+
+ # Compute the summary statistics
+ txt = textwrap.indent(f"Results on {label} dataset:\n", " " * 4)
+ results_txt += txt
+ print(txt, end="")
+ for metric in [
+ accuracy_score,
+ confusion_matrix,
+ ]:
+ txt = textwrap.indent(f"{metric.__name__}:\n", " " * 8)
+ results_txt += txt
+ print(txt, end="")
+ val = metric(y, y_pred)
+ txt = textwrap.indent(f"{val}\n", " " * 12)
+ results_txt += txt
+ print(txt, end="")
+
+ results_txt += "\n"
+ print(f"Done evaluating model {model_id}.")
+
+ # Save the results
+ results_df = pd.DataFrame(results_df, columns=results_df_columns)
+ results_df.to_csv(
+ os.path.join(
+ absolute_output_dir, f"{time.strftime('%Y_%m_%d_%H_%M_%S')}_results.csv"
+ )
+ )
+ with open(
+ os.path.join(
+ absolute_output_dir, f"{time.strftime('%Y_%m_%d_%H_%M_%S')}_results.txt"
+ ),
+ "w",
+ encoding="utf-8",
+ ) as f:
+ f.write(results_txt)
+
+
+def main() -> None:
+ """
+ Train and test a FoodOnForkDetector as configured by the command line arguments.
+ """
+ # pylint: disable=too-many-locals
+ # This is fine since it is the main function.
+
+ # Load the arguments
+ args = read_args()
+
+ # Load the models
+ print("*" * 80)
+ print(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}")
+ if args.seed is None:
+ seed = int(time.time())
+ else:
+ seed = args.seed
+ models = load_models(
+ args.model_classes,
+ args.model_kwargs,
+ seed,
+ args.crop_top_left,
+ args.crop_bottom_right,
+ )
+
+ # Load the dataset
+ print("*" * 80)
+ print(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}")
+ X, y, t, camera_info = load_data(
+ models,
+ args.data_dir,
+ args.depth_topic,
+ args.color_topic,
+ args.camera_info_topic,
+ args.crop_top_left,
+ args.crop_bottom_right,
+ args.depth_min_mm,
+ args.depth_max_mm,
+ args.exclude_motion,
+ args.rosbags_select,
+ args.rosbags_skip,
+ args.temporal_window_size,
+ args.spatial_num_pixels,
+ viz=args.viz_rosbags,
+ )
+ for _, model in models.items():
+ model.camera_info = camera_info
+
+ # Do a train-test split of the dataset
+ print("*" * 80)
+ print(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}")
+ print("Splitting the dataset...")
+ model_ids = list(models.keys())
+ split_data = train_test_split(
+ X,
+ y,
+ *[t[model_id] for model_id in model_ids],
+ train_size=args.train_set_size,
+ random_state=seed,
+ )
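+    # train_test_split returns a (train, test) pair for each array passed in, in
+    # order: X, y, then one transform array per model ID.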
+ train_X = split_data[0]
+ test_X = split_data[1]
+ train_y = split_data[2]
+ test_y = split_data[3]
+ train_t = {model_id: split_data[i * 2 + 4] for i, model_id in enumerate(model_ids)}
+ test_t = {model_id: split_data[i * 2 + 5] for i, model_id in enumerate(model_ids)}
+ print(f"Done. Train size: {train_X.shape[0]}, Test size: {test_X.shape[0]}")
+
+ # Train the model
+ if not args.no_train:
+ print("*" * 80)
+ print(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}")
+ train_models(
+ models, train_X, train_y, train_t, args.model_dir, args.viz_fit_save_dir
+ )
+
+ # Evaluate the model
+ print("*" * 80)
+ print(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}")
+ evaluate_models(
+ models,
+ train_X,
+ test_X,
+ train_y,
+ test_y,
+ train_t,
+ test_t,
+ args.model_dir,
+ args.output_dir,
+ args.lower_thresh,
+ args.upper_thresh,
+ args.max_eval_n,
+ viz=args.viz_evaluation,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/ada_feeding_perception/ada_feeding_perception/helpers.py b/ada_feeding_perception/ada_feeding_perception/helpers.py
index c1a89226..05389a10 100644
--- a/ada_feeding_perception/ada_feeding_perception/helpers.py
+++ b/ada_feeding_perception/ada_feeding_perception/helpers.py
@@ -4,23 +4,220 @@
# Standard imports
import os
import pprint
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
from urllib.parse import urljoin
from urllib.request import urlretrieve
# Third-party imports
import cv2
from cv_bridge import CvBridge
+import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import rclpy
from rclpy.node import Node
+
+try:
+ from rosbags.typesys.types import (
+ sensor_msgs__msg__CompressedImage as rCompressedImage,
+ sensor_msgs__msg__Image as rImage,
+ )
+except (TypeError, ModuleNotFoundError) as err:
+ rclpy.logging.get_logger("ada_feeding_perception_helpers").warn(
+ "rosbags is not installed, or a wrong version is installed (needs 0.9.19). "
+ f"Typechecking against rosbag types will not work. Error: {err}"
+ )
from sensor_msgs.msg import CompressedImage, Image
from skimage.morphology import flood_fill
+def show_normalized_depth_img(
+    img: npt.NDArray, wait: bool = True, window_name: str = "img"
+) -> None:
+ """
+ Show the normalized depth image. Useful for debugging.
+
+ Parameters
+ ----------
+ img: npt.NDArray
+ The depth image to show.
+ wait: bool, optional
+ If True, wait for a key press before closing the window.
+ window_name: str, optional
+ The name of the window to show the image in.
+ """
+ # Show the normalized depth image
+ img_normalized = ((img - img.min()) / (img.max() - img.min()) * 255).astype("uint8")
+ cv2.imshow(window_name, img_normalized)
+ cv2.waitKey(0 if wait else 1)
+
+
+def show_3d_scatterplot(
+ pointclouds: List[npt.NDArray],
+ colors: List[npt.NDArray],
+ sizes: List[int],
+ markerstyles: List[str],
+ labels: List[str],
+ title: str,
+ mean_colors: Optional[List[npt.NDArray]] = None,
+ mean_sizes: Optional[List[int]] = None,
+ mean_markerstyles: Optional[List[str]] = None,
+):
+ """
+ Show a 3D scatterplot of the given point clouds.
+
+ Parameters
+ ----------
+ pointclouds: List[npt.NDArray]
+        The point clouds to show. Each point cloud should be an Nx3 array of
+        points.
+    colors: List[npt.NDArray]
+        The colors to use for the points. Each color should be a size-3 array of
+        RGB values in the range [0, 1].
+ sizes: List[int]
+ The sizes to use for the points.
+ markerstyles: List[str]
+ The marker styles to use for the point clouds.
+ labels: List[str]
+ The labels to use for the point clouds.
+ title: str
+ The title of the plot.
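+    mean_colors, mean_sizes, mean_markerstyles: optional
+        If all three are provided, the mean of each point cloud is also plotted,
+        using the corresponding color, size, and marker style.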
+ """
+ # pylint: disable=too-many-arguments, too-many-locals
+ # This is meant to be a flexible function to help with debugging.
+
+ # Check that the inputs are valid
+ assert (
+ len(pointclouds)
+ == len(colors)
+ == len(sizes)
+ == len(markerstyles)
+ == len(labels)
+ )
+ if mean_colors is not None:
+ assert mean_sizes is not None
+ assert mean_markerstyles is not None
+ assert len(mean_colors) == len(mean_sizes) == len(mean_markerstyles)
+
+ # Create the plot
+ fig = plt.figure()
+ ax = fig.add_subplot(111, projection="3d")
+
+ # Plot each point cloud
+ configs = [pointclouds, colors, sizes, markerstyles, labels]
+ if mean_colors is not None:
+ configs += [mean_colors, mean_sizes, mean_markerstyles]
+ for config in zip(*configs):
+ pointcloud = config[0]
+ color = config[1]
+ size = config[2]
+ markerstyle = config[3]
+ label = config[4]
+ ax.scatter(
+ pointcloud[:, 0],
+ pointcloud[:, 1],
+ pointcloud[:, 2],
+ color=color,
+ s=size,
+ label=label,
+ marker=markerstyle,
+ )
+ if len(config) > 5:
+ mean_color = config[5]
+ mean_size = config[6]
+ mean_markerstyle = config[7]
+ mean = pointcloud.mean(axis=0)
+ ax.scatter(
+ mean[0].reshape((1, 1)),
+ mean[1].reshape((1, 1)),
+ mean[2].reshape((1, 1)),
+ color=mean_color,
+ s=mean_size,
+ label=label + " mean",
+ marker=mean_markerstyle,
+ )
+
+ # Set the title and labels
+ ax.set_title(title)
+ ax.set_xlabel("X")
+ ax.set_ylabel("Y")
+ ax.set_zlabel("Z")
+ ax.legend()
+
+ # Show the plot
+ plt.show()
+
+
+def depth_img_to_pointcloud(
+ depth_image: npt.NDArray,
+ u_offset: int,
+ v_offset: int,
+ f_x: float,
+ f_y: float,
+ c_x: float,
+ c_y: float,
+ unit_conversion: float = 1000.0,
+ transform: Optional[npt.NDArray] = None,
+) -> npt.NDArray:
+ """
+ Converts a depth image to a point cloud.
+
+ Parameters
+ ----------
+ depth_image: The depth image to convert to a point cloud.
+ u_offset: An offset to add to the column index of every pixel in the depth
+ image. This is useful if the depth image was cropped.
+ v_offset: An offset to add to the row index of every pixel in the depth
+ image. This is useful if the depth image was cropped.
+ f_x: The focal length of the camera in the x direction, using the pinhole
+ camera model.
+ f_y: The focal length of the camera in the y direction, using the pinhole
+ camera model.
+ c_x: The x-coordinate of the principal point of the camera, using the pinhole
+ camera model.
+ c_y: The y-coordinate of the principal point of the camera, using the pinhole
+ camera model.
+ unit_conversion: The depth values are divided by this constant. Defaults to 1000,
+ as RealSense returns depth in mm, but we want the pointcloud in m.
+ transform: An optional transform to apply to the point cloud. If set, this should
+ be a 4x4 matrix.
+
+ Returns
+ -------
+ pointcloud: The point cloud representation of the depth image.
+ """
+ # pylint: disable=too-many-arguments
+ # Although we could reduce it by passing in a camera matrix, I prefer to
+ # keep the arguments explicit.
+
+ # Get the pixel coordinates
+ pixel_coords = np.mgrid[: depth_image.shape[0], : depth_image.shape[1]]
+ pixel_coords[0] += v_offset
+ pixel_coords[1] += u_offset
+
+ # Mask out values outside the depth range
+ mask = depth_image > 0
+ depth_values = depth_image[mask]
+ pixel_coords = pixel_coords[:, mask]
+
+ # Convert units (e.g., mm to m)
+ depth_values = np.divide(depth_values, unit_conversion)
+
+ # Convert to 3D coordinates
+ pointcloud = np.zeros((depth_values.shape[0], 3))
+ pointcloud[:, 0] = np.multiply(pixel_coords[1] - c_x, np.divide(depth_values, f_x))
+ pointcloud[:, 1] = np.multiply(pixel_coords[0] - c_y, np.divide(depth_values, f_y))
+ pointcloud[:, 2] = depth_values
+
+ # Apply the transform if it exists
+ if transform is not None:
+ pointcloud = np.hstack((pointcloud, np.ones((pointcloud.shape[0], 1))))
+ pointcloud = np.dot(transform, pointcloud.T).T[:, :3]
+
+ return pointcloud
+
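+# A minimal usage sketch (hypothetical variable name; the intrinsics and crop
+# offsets below mirror config/food_on_fork_detection.yaml and are illustrative):
+#
+#     cloud = depth_img_to_pointcloud(
+#         cropped_depth_img,  # e.g., a 64x64 uint16 patch cropped at (344, 272)
+#         344,                # u_offset of the crop's top-left corner
+#         272,                # v_offset of the crop's top-left corner
+#         f_x=614.59, f_y=614.69, c_x=312.14, c_y=223.71,
+#     )
+#     # cloud is an Nx3 float array of (x, y, z) points in meters, one per
+#     # pixel with nonzero depth.
+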
+
def ros_msg_to_cv2_image(
- msg: Union[Image, CompressedImage], bridge: Optional[CvBridge] = None
+ msg: Union[Image, rImage, CompressedImage, rCompressedImage],
+ bridge: Optional[CvBridge] = None,
) -> npt.NDArray:
"""
Convert a ROS Image or CompressedImage message to a cv2 image. By default,
@@ -39,12 +236,20 @@ def ros_msg_to_cv2_image(
is a ROS Image message. If `bridge` is None, a new CvBridge will be
created.
"""
- if isinstance(msg, Image):
- if bridge is None:
- bridge = CvBridge()
+ image_types = [Image]
+ compressed_image_types = [CompressedImage]
+ try:
+ image_types.append(rImage)
+ compressed_image_types.append(rCompressedImage)
+ except NameError as _:
+ # This only happens if rosbags wasn't imported, which is logged above.
+ pass
+ if bridge is None:
+ bridge = CvBridge()
+ if isinstance(msg, tuple(image_types)):
return bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")
- if isinstance(msg, CompressedImage):
- return cv2.imdecode(np.frombuffer(msg.data, np.uint8), cv2.IMREAD_UNCHANGED)
+ if isinstance(msg, tuple(compressed_image_types)):
+ return bridge.compressed_imgmsg_to_cv2(msg, desired_encoding="passthrough")
raise ValueError("msg must be a ROS Image or CompressedImage")
@@ -76,13 +281,7 @@ def cv2_image_to_ros_msg(
if bridge is None:
bridge = CvBridge()
if compress:
- success, compressed_image = cv2.imencode(".jpg", image)
- if success:
- return CompressedImage(
- format="jpeg",
- data=compressed_image.tostring(),
- )
- raise RuntimeError("Failed to compress image")
+ return bridge.cv2_to_compressed_imgmsg(image, dst_format="jpeg")
# If we get here, we're not compressing the image
return bridge.cv2_to_imgmsg(image, encoding=encoding)
diff --git a/ada_feeding_perception/config/food_on_fork_detection.yaml b/ada_feeding_perception/config/food_on_fork_detection.yaml
new file mode 100644
index 00000000..82fc4126
--- /dev/null
+++ b/ada_feeding_perception/config/food_on_fork_detection.yaml
@@ -0,0 +1,34 @@
+# NOTE: You have to change this node name if you change the node name in the launchfile.
+food_on_fork_detection:
+ ros__parameters:
+ # The FoodOnFork class to use
+ model_class: "ada_feeding_perception.food_on_fork_detectors.FoodOnForkDistanceToNoFOFDetector"
+    # The path to load the model from. Ignored if it is the empty string.
+ # Should be relative to the `model_dir` parameter, specified in the launchfile.
+ model_path: "distance_no_fof_detector_with_filters.npz"
+    # The names of the keyword arguments to pass to the FoodOnFork class's constructor
+ model_kws:
+ - camera_matrix
+ # Keyword arguments to pass to the FoodOnFork class's constructor
+ model_kwargs:
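+      # The flattened 3x3 camera intrinsic matrix, row-major: [f_x, 0, c_x, 0, f_y, c_y, 0, 0, 1]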
+ camera_matrix: [614.5933227539062, 0.0, 312.1358947753906, 0.0, 614.6914672851562, 223.70831298828125, 0.0, 0.0, 1.0]
+
+ # The rate at which to detect and publish the confidence that there is food on the fork
+ rate_hz: 10.0
+ # The top-left and bottom-right corners to crop the depth image to
+ crop_top_left: [344, 272]
+ crop_bottom_right: [408, 336]
+ # The min and max depth to consider for the food on the fork
+ depth_min_mm: 310
+ depth_max_mm: 340
+
+ # The size of the temporal window for the "temporal" post-processor.
+ temporal_window_size: 5
+ # The size of the square kernel for the "spatial" post-processor.
+ spatial_num_pixels: 10
+
+ # Whether to visualize the output of the detector
+ viz: True
+ # The upper and lower thresholds for the visualization to say there is(n't) food-on-fork
+ viz_lower_thresh: 0.25
+ viz_upper_thresh: 0.75
diff --git a/ada_feeding_perception/launch/ada_feeding_perception.launch.py b/ada_feeding_perception/launch/ada_feeding_perception.launch.py
index 309240c1..be16da0f 100755
--- a/ada_feeding_perception/launch/ada_feeding_perception.launch.py
+++ b/ada_feeding_perception/launch/ada_feeding_perception.launch.py
@@ -145,4 +145,36 @@ def generate_launch_description():
)
launch_description.add_action(face_detection)
+ # Load the food-on-fork detection node
+ food_on_fork_detection_config = os.path.join(
+ ada_feeding_perception_share_dir, "config", "food_on_fork_detection.yaml"
+ )
+ food_on_fork_detection_params = {}
+ food_on_fork_detection_params["model_dir"] = ParameterValue(
+ os.path.join(ada_feeding_perception_share_dir, "model"), value_type=str
+ )
+ food_on_fork_detection_remappings = [
+ ("~/food_on_fork_detection", "/food_on_fork_detection"),
+ ("~/food_on_fork_detection_img", "/food_on_fork_detection_img"),
+ ("~/toggle_food_on_fork_detection", "/toggle_food_on_fork_detection"),
+ (
+ "~/aligned_depth",
+ PythonExpression(
+ expression=[
+ "'",
+ prefix,
+ "/camera/aligned_depth_to_color/image_raw'",
+ ]
+ ),
+ ),
+ ]
+ food_on_fork_detection = Node(
+ package="ada_feeding_perception",
+ name="food_on_fork_detection",
+ executable="food_on_fork_detection",
+ parameters=[food_on_fork_detection_config, food_on_fork_detection_params],
+ remappings=realsense_remappings + food_on_fork_detection_remappings,
+ )
+ launch_description.add_action(food_on_fork_detection)
+
return launch_description
diff --git a/ada_feeding_perception/model/distance_no_fof_detector_with_filters.npz b/ada_feeding_perception/model/distance_no_fof_detector_with_filters.npz
new file mode 100644
index 00000000..70a91b3d
Binary files /dev/null and b/ada_feeding_perception/model/distance_no_fof_detector_with_filters.npz differ
diff --git a/ada_feeding_perception/package.xml b/ada_feeding_perception/package.xml
index ed90ffe8..9ab76b62 100644
--- a/ada_feeding_perception/package.xml
+++ b/ada_feeding_perception/package.xml
@@ -18,12 +18,17 @@
python3-matplotlib
python3-numpy
python3-opencv
+ python3-pandas
+
+
python3-scikit-spatial-pip
python3-scipy
python3-shapely
python3-skimage
python3-torch
python3-torchvision
+ python-transforms3d-pip
ament_python
diff --git a/ada_feeding_perception/setup.py b/ada_feeding_perception/setup.py
index 8a002894..e3edf700 100644
--- a/ada_feeding_perception/setup.py
+++ b/ada_feeding_perception/setup.py
@@ -24,7 +24,7 @@
# Include all model files.
(
os.path.join("share", package_name, "model"),
- glob(os.path.join("model", "*")),
+ glob(os.path.join("model", "*.*")),
),
# Include all config files.
(
@@ -46,6 +46,7 @@
tests_require=["pytest"],
entry_points={
"console_scripts": [
+ "food_on_fork_detection = ada_feeding_perception.food_on_fork_detection:main",
"republisher = ada_feeding_perception.republisher:main",
"segment_from_point = ada_feeding_perception.segment_from_point:main",
"test_segment_from_point = ada_feeding_perception.test_segment_from_point:main",