Skip to content

Commit

Permalink
Support latched topics in udp bridge
Browse files Browse the repository at this point in the history
  • Loading branch information
timonegk committed Apr 17, 2024
1 parent 818e882 commit 1082ec7
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 21 deletions.
4 changes: 2 additions & 2 deletions udp_bridge/aes_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ def encrypt(self, message: str) -> bytes:
if self.encryption_key is None:
return bytes(message, encoding="UTF-8")

return Fernet(key=self.encryption_key).encrypt(bytes(message, encoding="UTF-8"))
return Fernet(key=self.encryption_key).encrypt(message)

def decrypt(self, enc: bytes) -> str:
if len(enc) == 0:
raise ValueError("Cannot decrypt empty data")
if self.encryption_key is None:
return str(enc, encoding="UTF-8")

return str(Fernet(key=self.encryption_key).decrypt(enc), encoding="UTF-8")
return Fernet(key=self.encryption_key).decrypt(enc)
9 changes: 4 additions & 5 deletions udp_bridge/message_handler.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
import base64
import pickle
import zlib

from udp_bridge.aes_helper import AESCipher


class MessageHandler:
PACKAGE_DELIMITER = b"\xff\xff\xff"

def __init__(self, encryption_key: str | None):
self.cipher = AESCipher(encryption_key)

def encrypt_and_encode(self, data: dict) -> bytes:
serialized_data = base64.b64encode(pickle.dumps(data, pickle.HIGHEST_PROTOCOL)).decode("ASCII")
serialized_data = zlib.compress(pickle.dumps(data, pickle.HIGHEST_PROTOCOL))
return self.cipher.encrypt(serialized_data)

def dencrypt_and_decode(self, msg: bytes):
def decrypt_and_decode(self, msg: bytes):
decrypted_msg = self.cipher.decrypt(msg)
binary_msg = base64.b64decode(decrypted_msg)
binary_msg = zlib.decompress(decrypted_msg)
return pickle.loads(binary_msg)
12 changes: 4 additions & 8 deletions udp_bridge/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,11 @@ def recv_message(self):
"""
Receive a message from the network, process it and publish it into ROS
"""
acc = b""
while rclpy.ok():
try:
acc += self.sock.recv(10240)

if acc[-3:] == MessageHandler.PACKAGE_DELIMITER:
self.handle_message(acc[:-3])
acc = b""

# 65535 is the upper limit for the size because of network properties
msg = self.sock.recv(65535)
self.handle_message(msg)
except socket.timeout:
pass

Expand All @@ -49,7 +45,7 @@ def handle_message(self, msg: bytes):
Handle a new message which came in from the socket
"""
try:
deserialized_msg = self.message_handler.dencrypt_and_decode(msg)
deserialized_msg = self.message_handler.decrypt_and_decode(msg)
data = deserialized_msg.get("data")
topic: str = deserialized_msg.get("topic")
hostname: str = deserialized_msg.get("hostname")
Expand Down
30 changes: 24 additions & 6 deletions udp_bridge/sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import socket
from queue import Empty, Full, Queue
import zlib

import rclpy
from bitbots_utils.utils import get_parameters_from_other_node
from rclpy.executors import SingleThreadedExecutor
from rclpy.node import Node
from rclpy.subscription import Subscription
from rclpy.timer import Timer
from rclpy.qos import DurabilityPolicy, QoSProfile
from ros2topic.api import get_msg_class, get_topic_names

from udp_bridge.message_handler import MessageHandler
Expand Down Expand Up @@ -36,6 +38,7 @@ def __init__(self, topic: str, queue_size: int, message_handler: MessageHandler,
self.timer: Timer | None = None

self.__subscriber: Subscription | None = None
self.__latched_subscriber: Subscription | None = None
self.__subscribe()

def __subscribe(self, backoff=1.0):
Expand All @@ -55,19 +58,28 @@ def __subscribe(self, backoff=1.0):

if topic is not None:
data_class = get_msg_class(self.node, topic)
self.node.get_logger().info(str(data_class))

if data_class is not None:
# topic is known
self.node.get_logger().info(f"Want to subscribe to topic {self.topic}")
self.node.get_logger().debug(f"Want to subscribe to topic {self.topic}")
# find out if topic is latched / transient local
publisher_infos = self.node.get_publishers_info_by_topic(topic)
latched = any(info.qos_profile.durability == DurabilityPolicy.TRANSIENT_LOCAL for info in publisher_infos)
self.__subscriber = self.node.create_subscription(data_class, self.topic, self.__message_callback, 1)
self.node.get_logger().info(f"Subscribed to topic {self.topic}")
if latched:
self.__latched_subscriber = self.node.create_subscription(
data_class,
self.topic,
lambda msg: self.__message_callback(msg, latched=True),
QoSProfile(depth=1, durability=DurabilityPolicy.TRANSIENT_LOCAL)
)
self.node.get_logger().debug(f"Subscribed to topic {self.topic}")
else:
# topic is not yet known
self.node.get_logger().info(f"Topic {self.topic} is not yet known. Retrying in {backoff} seconds")
self.node.get_logger().debug(f"Topic {self.topic} is not yet known. Retrying in {backoff} seconds")
self.timer = self.node.create_timer(backoff, lambda: self.__subscribe(backoff * 1.2))

def __message_callback(self, data):
def __message_callback(self, data, latched=False):
encrypted_msg = self.message_handler.encrypt_and_encode(
{
"data": data,
Expand All @@ -81,6 +93,11 @@ def __message_callback(self, data):
except Full:
self.node.get_logger().warn(f"Could not enqueue new message of topic {self.topic}. Queue full.")

# for latched messages, republish them every ten seconds because we cannot latch on the other side
if latched:
if self.timer:
self.timer.cancel()
self.timer = self.node.create_timer(10.0, lambda: self.__message_callback(data, latched=True))

# @TODO: replace by usage of https://github.com/PickNikRobotics/generate_parameter_library
def validate_params(node: Node) -> bool:
Expand Down Expand Up @@ -133,6 +150,7 @@ def __init__(self, node: Node):
self.sock = self.setup_udp_socket()

topics: list[str] = node.get_parameter("topics").value
print(topics)
max_queue_size: int = node.get_parameter("sender_queue_max_size").value
message_handler = self.setup_message_handler()
self.subscribers: list[AutoSubscriber] = list(
Expand All @@ -157,7 +175,7 @@ def send_messages_in_queue(self):
data = subscriber.queue.get_nowait()

try:
self.sock.sendto(data + MessageHandler.PACKAGE_DELIMITER, (self.target, self.port))
self.sock.sendto(data, (self.target, self.port))
except Exception as e:
self.node.get_logger().error(
f"Could not send data of topic {subscriber.topic} to {self.target} with error {str(e)}"
Expand Down

0 comments on commit 1082ec7

Please sign in to comment.