Skip to content

Commit

Permalink
Added TTD to prediction and traffic light detection. (#179)
Browse files Browse the repository at this point in the history
  • Loading branch information
ICGog authored Mar 30, 2021
1 parent 64d0d9d commit 342c32e
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 30 deletions.
4 changes: 2 additions & 2 deletions lincoln.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def create_data_flow():
# The right camera is more likely to contain the traffic lights.
traffic_lights_stream = \
pylot.operator_creator.add_traffic_light_detector(
right_camera_stream)
right_camera_stream, time_to_decision_loop_stream)
# Adds operator that finds the world locations of the traffic lights.
traffic_lights_stream = \
pylot.operator_creator.add_obstacle_location_finder(
Expand Down Expand Up @@ -160,7 +160,7 @@ def create_data_flow():

if FLAGS.prediction:
prediction_stream = pylot.operator_creator.add_linear_prediction(
obstacles_tracking_stream)
obstacles_tracking_stream, time_to_decision_loop_stream)
else:
prediction_stream = obstacles_stream

Expand Down
9 changes: 5 additions & 4 deletions pylot.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def driver():
traffic_lights_stream, tl_camera_stream = \
pylot.component_creator.add_traffic_light_detection(
tl_transform, vehicle_id_stream, release_sensor_stream,
pose_stream, depth_stream, ground_traffic_lights_stream)
pose_stream, depth_stream, ground_traffic_lights_stream,
time_to_decision_loop_stream)

lane_detection_stream = pylot.component_creator.add_lane_detection(
center_camera_stream, pose_stream, open_drive_stream)
Expand Down Expand Up @@ -159,9 +160,9 @@ def driver():

prediction_stream, prediction_camera_stream, notify_prediction_stream = \
pylot.component_creator.add_prediction(
obstacles_tracking_stream, vehicle_id_stream, transform,
release_sensor_stream, pose_stream, point_cloud_stream,
lidar_setup)
obstacles_tracking_stream, vehicle_id_stream,
time_to_decision_loop_stream, transform, release_sensor_stream,
pose_stream, point_cloud_stream, lidar_setup)
if prediction_stream is None:
prediction_stream = obstacles_stream
if notify_prediction_stream:
Expand Down
15 changes: 9 additions & 6 deletions pylot/component_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def add_traffic_light_detection(tl_transform,
release_sensor_stream,
pose_stream=None,
depth_stream=None,
ground_traffic_lights_stream=None):
ground_traffic_lights_stream=None,
time_to_decision_stream=None):
"""Adds traffic light detection operators.
The traffic light detectors use a camera with a narrow field of view.
Expand Down Expand Up @@ -174,7 +175,7 @@ def add_traffic_light_detection(tl_transform,
logger.debug('Using traffic light detection...')
traffic_lights_stream = \
pylot.operator_creator.add_traffic_light_detector(
tl_camera_stream)
tl_camera_stream, time_to_decision_stream)
# Adds operator that finds the world locations of the traffic lights.
traffic_lights_stream = \
pylot.operator_creator.add_obstacle_location_finder(
Expand Down Expand Up @@ -432,8 +433,9 @@ def add_segmentation(center_camera_stream, ground_segmented_stream=None):

def add_prediction(obstacles_tracking_stream,
vehicle_id_stream,
camera_transform,
release_sensor_stream,
time_to_decision_stream,
camera_transform=None,
release_sensor_stream=None,
pose_stream=None,
point_cloud_stream=None,
lidar_setup=None):
Expand Down Expand Up @@ -465,13 +467,14 @@ def add_prediction(obstacles_tracking_stream,
if FLAGS.prediction_type == 'linear':
logger.debug('Using linear prediction...')
prediction_stream = pylot.operator_creator.add_linear_prediction(
obstacles_tracking_stream)
obstacles_tracking_stream, time_to_decision_stream)
elif FLAGS.prediction_type == 'r2p2':
logger.debug('Using R2P2 prediction...')
assert point_cloud_stream is not None
assert lidar_setup is not None
prediction_stream = pylot.operator_creator.add_r2p2_prediction(
point_cloud_stream, obstacles_tracking_stream, lidar_setup)
point_cloud_stream, obstacles_tracking_stream,
time_to_decision_stream, lidar_setup)
else:
raise ValueError('Unexpected prediction_type {}'.format(
FLAGS.prediction_type))
Expand Down
26 changes: 14 additions & 12 deletions pylot/operator_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,17 +179,19 @@ def add_control_evaluation(pose_stream,
[pose_stream, waypoints_stream], FLAGS)


def add_traffic_light_detector(traffic_light_camera_stream):
def add_traffic_light_detector(traffic_light_camera_stream,
time_to_decision_stream):
from pylot.perception.detection.traffic_light_det_operator import \
TrafficLightDetOperator
op_config = erdos.OperatorConfig(name='traffic_light_detector_operator',
flow_watermarks=False,
log_file_name=FLAGS.log_file_name,
csv_log_file_name=FLAGS.csv_log_file_name,
profile_file_name=FLAGS.profile_file_name)
[traffic_lights_stream] = erdos.connect(TrafficLightDetOperator, op_config,
[traffic_light_camera_stream],
FLAGS)
[traffic_lights_stream
] = erdos.connect(TrafficLightDetOperator, op_config,
[traffic_light_camera_stream, time_to_decision_stream],
FLAGS)
return traffic_lights_stream


Expand Down Expand Up @@ -345,30 +347,30 @@ def add_segmentation_decay(ground_segmented_stream,
return iou_stream


def add_linear_prediction(tracking_stream):
def add_linear_prediction(tracking_stream, time_to_decision_stream):
from pylot.prediction.linear_predictor_operator import \
LinearPredictorOperator
op_config = erdos.OperatorConfig(name='linear_prediction_operator',
log_file_name=FLAGS.log_file_name,
csv_log_file_name=FLAGS.csv_log_file_name,
profile_file_name=FLAGS.profile_file_name)
[prediction_stream] = erdos.connect(LinearPredictorOperator, op_config,
[tracking_stream], FLAGS)
[prediction_stream
] = erdos.connect(LinearPredictorOperator, op_config,
[tracking_stream, time_to_decision_stream], FLAGS)
return prediction_stream


def add_r2p2_prediction(point_cloud_stream, obstacles_tracking_stream,
lidar_setup):
time_to_decision_stream, lidar_setup):
from pylot.prediction.r2p2_predictor_operator import \
R2P2PredictorOperator
op_config = erdos.OperatorConfig(name='r2p2_prediction_operator',
log_file_name=FLAGS.log_file_name,
csv_log_file_name=FLAGS.csv_log_file_name,
profile_file_name=FLAGS.profile_file_name)
[prediction_stream
] = erdos.connect(R2P2PredictorOperator, op_config,
[point_cloud_stream, obstacles_tracking_stream], FLAGS,
lidar_setup)
[prediction_stream] = erdos.connect(R2P2PredictorOperator, op_config, [
point_cloud_stream, obstacles_tracking_stream, time_to_decision_stream
], FLAGS, lidar_setup)
return prediction_stream


Expand Down
2 changes: 1 addition & 1 deletion pylot/perception/detection/detection_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def destroy(self):
self._obstacles_stream.send(
erdos.WatermarkMessage(erdos.Timestamp(is_top=True)))

def on_time_to_decision_update(self, msg):
def on_time_to_decision_update(self, msg: erdos.Message):
self._logger.debug('@{}: {} received ttd update {}'.format(
msg.timestamp, self.config.name, msg))

Expand Down
9 changes: 8 additions & 1 deletion pylot/perception/detection/traffic_light_det_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ class TrafficLightDetOperator(erdos.Operator):
flags (absl.flags): Object to be used to access absl flags.
"""
def __init__(self, camera_stream: erdos.ReadStream,
time_to_decision_stream: erdos.ReadStream,
traffic_lights_stream: erdos.WriteStream, flags):
# Register a callback on the camera input stream.
camera_stream.add_callback(self.on_frame, [traffic_lights_stream])
time_to_decision_stream.add_callback(self.on_time_to_decision_update)
self._logger = erdos.utils.setup_logging(self.config.name,
self.config.log_file_name)
self._flags = flags
Expand Down Expand Up @@ -78,7 +80,8 @@ def __init__(self, camera_stream: erdos.ReadStream,
self.__run_model(np.zeros((108, 192, 3)))

@staticmethod
def connect(camera_stream: erdos.ReadStream):
def connect(camera_stream: erdos.ReadStream,
time_to_decision_stream: erdos.ReadStream):
"""Connects the operator to other streams.
Args:
Expand All @@ -98,6 +101,10 @@ def destroy(self):
self._traffic_lights_stream.send(
erdos.WatermarkMessage(erdos.Timestamp(is_top=True)))

def on_time_to_decision_update(self, msg: erdos.Message):
self._logger.debug('@{}: {} received ttd update {}'.format(
msg.timestamp, self.config.name, msg))

@erdos.profile_method()
def on_frame(self, msg: erdos.Message,
traffic_lights_stream: erdos.WriteStream):
Expand Down
9 changes: 8 additions & 1 deletion pylot/prediction/linear_predictor_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,18 @@ class LinearPredictorOperator(erdos.Operator):
flags (absl.flags): Object to be used to access absl flags.
"""
def __init__(self, tracking_stream: ReadStream,
time_to_decision_stream: ReadStream,
linear_prediction_stream: WriteStream, flags):
tracking_stream.add_callback(self.generate_predicted_trajectories,
[linear_prediction_stream])
time_to_decision_stream.add_callback(self.on_time_to_decision_update)
self._logger = erdos.utils.setup_logging(self.config.name,
self.config.log_file_name)
self._flags = flags

@staticmethod
def connect(tracking_stream: ReadStream):
def connect(tracking_stream: ReadStream,
time_to_decision_stream: ReadStream):
"""Connects the operator to other streams.
Args:
Expand All @@ -52,6 +55,10 @@ def connect(tracking_stream: ReadStream):
def destroy(self):
self._logger.warn('destroying {}'.format(self.config.name))

def on_time_to_decision_update(self, msg):
self._logger.debug('@{}: {} received ttd update {}'.format(
msg.timestamp, self.config.name, msg))

@erdos.profile_method()
def generate_predicted_trajectories(self, msg: Message,
linear_prediction_stream: WriteStream):
Expand Down
9 changes: 8 additions & 1 deletion pylot/prediction/r2p2_predictor_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class R2P2PredictorOperator(erdos.Operator):
"""
def __init__(self, point_cloud_stream: erdos.ReadStream,
tracking_stream: erdos.ReadStream,
time_to_decision_stream: erdos.ReadStream,
prediction_stream: erdos.WriteStream, flags, lidar_setup):
print("WARNING: R2P2 predicts only vehicle trajectories")
self._logger = erdos.utils.setup_logging(self.config.name,
Expand All @@ -52,6 +53,7 @@ def __init__(self, point_cloud_stream: erdos.ReadStream,

point_cloud_stream.add_callback(self.on_point_cloud_update)
tracking_stream.add_callback(self.on_trajectory_update)
time_to_decision_stream.add_callback(self.on_time_to_decision_update)
erdos.add_watermark_callback([point_cloud_stream, tracking_stream],
[prediction_stream], self.on_watermark)

Expand All @@ -62,7 +64,8 @@ def __init__(self, point_cloud_stream: erdos.ReadStream,

@staticmethod
def connect(point_cloud_stream: erdos.ReadStream,
tracking_stream: erdos.ReadStream):
tracking_stream: erdos.ReadStream,
time_to_decision_stream: erdos.ReadStream):
prediction_stream = erdos.WriteStream()
return [prediction_stream]

Expand Down Expand Up @@ -209,3 +212,7 @@ def on_trajectory_update(self, msg: erdos.Message):
self._logger.debug('@{}: received trajectories message'.format(
msg.timestamp))
self._tracking_msgs.append(msg)

def on_time_to_decision_update(self, msg: erdos.Message):
self._logger.debug('@{}: {} received ttd update {}'.format(
msg.timestamp, self.config.name, msg))
7 changes: 5 additions & 2 deletions pylot/simulation/challenge/ERDOSAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def create_data_flow():
# the small fov.
traffic_lights_stream = \
pylot.operator_creator.add_traffic_light_detector(
camera_streams[TL_CAMERA_NAME])
camera_streams[TL_CAMERA_NAME], time_to_decision_loop_stream)
# Adds an operator that finds the world location of the traffic lights.
# The operator synchronizes LiDAR point cloud readings with camera
# frames, and uses them to compute the depth to traffic light bounding
Expand Down Expand Up @@ -326,7 +326,10 @@ def create_data_flow():
# The agent uses a linear predictor to compute future trajectories
# of the other agents.
prediction_stream, _, _ = pylot.component_creator.add_prediction(
obstacles_tracking_stream, vehicle_id_stream, None, None, pose_stream)
obstacles_tracking_stream,
vehicle_id_stream,
time_to_decision_loop_stream,
pose_stream=pose_stream)

# Adds a planner to the agent. The planner receives the pose of
# the ego-vehicle, detected traffic lights, predictions for other
Expand Down

0 comments on commit 342c32e

Please sign in to comment.