Skip to content

Commit

Permalink
Fix remote code execution due to pickle (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
timonegk authored Apr 19, 2024
2 parents a2ca5da + 62705b8 commit 11e9dd1
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 33 deletions.
8 changes: 7 additions & 1 deletion udp_bridge/message_handler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import io
import pickle
import zlib

from udp_bridge.aes_helper import AESCipher


class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
raise pickle.UnpicklingError("pickle loading restricted to base types")


class MessageHandler:
def __init__(self, encryption_key: str | None):
self.cipher = AESCipher(encryption_key)
Expand All @@ -15,4 +21,4 @@ def encrypt_and_encode(self, data: dict) -> bytes:
def decrypt_and_decode(self, msg: bytes):
decrypted_msg = self.cipher.decrypt(msg)
binary_msg = zlib.decompress(decrypted_msg)
return pickle.loads(binary_msg)
return RestrictedUnpickler(io.BytesIO(binary_msg)).load()
5 changes: 4 additions & 1 deletion udp_bridge/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import rclpy
from rclpy.node import Node
from rclpy.qos import DurabilityPolicy, QoSProfile
from rclpy.serialization import deserialize_message
from rosidl_runtime_py.utilities import get_message

from udp_bridge.message_handler import MessageHandler

Expand Down Expand Up @@ -47,7 +49,8 @@ def handle_message(self, msg: bytes):
"""
try:
deserialized_msg = self.message_handler.decrypt_and_decode(msg)
data = deserialized_msg.get("data")
msg_type_name = deserialized_msg.get("msg_type_name")
data = deserialize_message(deserialized_msg.get("data"), get_message(msg_type_name))
topic: str = deserialized_msg.get("topic")
hostname: str = deserialized_msg.get("hostname")
latched: bool = deserialized_msg.get("latched")
Expand Down
65 changes: 34 additions & 31 deletions udp_bridge/sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from rclpy.logging import LoggingSeverity
from rclpy.node import Node
from rclpy.qos import DurabilityPolicy, QoSProfile
from rclpy.serialization import serialize_message
from rclpy.subscription import Subscription
from rclpy.timer import Timer
from ros2topic.api import get_msg_class, get_topic_names
from rosidl_runtime_py.utilities import get_message

from udp_bridge.message_handler import MessageHandler

Expand All @@ -36,6 +37,7 @@ def __init__(self, topic: str, queue_size: int, message_handler: MessageHandler,
self.message_handler: MessageHandler = message_handler
self.node: Node = node
self.timer: Timer | None = None
self.msg_type_name: str = None

self.__subscriber: Subscription | None = None
self.__latched_subscriber: Subscription | None = None
Expand All @@ -53,43 +55,44 @@ def __subscribe(self, backoff=1.0):
self.timer.cancel()

data_class = None
topics = get_topic_names(node=self.node)
topic = next(filter(lambda t: t == self.topic, topics), None)

if topic is not None:
data_class = get_msg_class(self.node, topic)

if data_class is not None:
# topic is known
self.node.get_logger().debug(f"Want to subscribe to topic {self.topic}")
# find out if topic is latched / transient local
publisher_infos = self.node.get_publishers_info_by_topic(topic)
latched = any(info.qos_profile.durability == DurabilityPolicy.TRANSIENT_LOCAL for info in publisher_infos)
self.__subscriber = self.node.create_subscription(data_class, self.topic, self.__message_callback, 1)
if latched:
self.__latched_subscriber = self.node.create_subscription(
data_class,
self.topic,
lambda msg: self.__message_callback(msg, latched=True),
QoSProfile(depth=1, durability=DurabilityPolicy.TRANSIENT_LOCAL),
for topic, msg_type_names in self.node.get_topic_names_and_types():
if topic == self.topic:
self.msg_type_name = msg_type_names[0]
data_class = get_message(self.msg_type_name)
# topic is known
self.node.get_logger().debug(f"Want to subscribe to topic {self.topic}")
# find out if topic is latched / transient local
publisher_infos = self.node.get_publishers_info_by_topic(topic)
latched = any(
info.qos_profile.durability == DurabilityPolicy.TRANSIENT_LOCAL for info in publisher_infos
)
self.node.get_logger().debug(f"Subscribed to topic {self.topic}")
self.__subscriber = self.node.create_subscription(data_class, self.topic, self.__message_callback, 1)
if latched:
self.__latched_subscriber = self.node.create_subscription(
data_class,
self.topic,
lambda msg: self.__message_callback(msg, latched=True),
QoSProfile(depth=1, durability=DurabilityPolicy.TRANSIENT_LOCAL),
)
self.node.get_logger().debug(f"Subscribed to topic {self.topic}")
return

# topic is not yet known
if backoff > 10:
logging_severity = LoggingSeverity.WARN
else:
# topic is not yet known
if backoff > 10:
logging_severity = LoggingSeverity.WARN
else:
logging_severity = LoggingSeverity.DEBUG
self.node.get_logger().log(
f"Topic {self.topic} is not yet known. Retrying in {backoff} seconds", logging_severity
)
self.timer = self.node.create_timer(backoff, lambda: self.__subscribe(backoff * 1.2))
logging_severity = LoggingSeverity.DEBUG
self.node.get_logger().log(
f"Topic {self.topic} is not yet known. Retrying in {backoff} seconds", logging_severity
)
self.timer = self.node.create_timer(backoff, lambda: self.__subscribe(backoff * 1.2))

def __message_callback(self, data, latched=False):
encrypted_msg = self.message_handler.encrypt_and_encode(
{
"data": data,
"data": serialize_message(data),
"topic": self.topic,
"msg_type_name": self.msg_type_name,
"hostname": HOSTNAME,
"latched": latched,
}
Expand Down

0 comments on commit 11e9dd1

Please sign in to comment.