Skip to content

Commit

Permalink
Improved flag validation and added more type hints (#178)
Browse files Browse the repository at this point in the history
* More type hints.

* Updated the scripts to ensure ENV variables are set.

* Improved flag validation.

* Ensure end-to-end tests are not executed by pytest.
  • Loading branch information
ICGog authored Mar 29, 2021
1 parent 256d8c5 commit 64d0d9d
Show file tree
Hide file tree
Showing 30 changed files with 410 additions and 208 deletions.
3 changes: 1 addition & 2 deletions pylot.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,7 @@ def driver():
control_loop_stream.set(control_stream)

pylot.component_creator.add_evaluation(vehicle_id_stream, pose_stream,
imu_stream, pose_stream_for_control,
waypoints_stream_for_control)
imu_stream)

time_to_decision_stream = pylot.operator_creator.add_time_to_decision(
pose_stream, obstacles_stream)
Expand Down
87 changes: 63 additions & 24 deletions pylot/component_creator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

from absl import flags

import pylot.operator_creator
Expand All @@ -6,6 +8,8 @@

FLAGS = flags.FLAGS

logger = logging.getLogger(__name__)


def add_obstacle_detection(center_camera_stream,
center_camera_setup=None,
Expand Down Expand Up @@ -65,16 +69,19 @@ def add_obstacle_detection(center_camera_stream,
obstacles_stream_wo_depth = None
if any('efficientdet' in model
for model in FLAGS.obstacle_detection_model_names):
logger.debug('Using EfficientDet obstacle detector...')
obstacles_streams = pylot.operator_creator.\
add_efficientdet_obstacle_detection(
center_camera_stream, time_to_decision_stream)
add_efficientdet_obstacle_detection(
center_camera_stream, time_to_decision_stream)
obstacles_stream_wo_depth = obstacles_streams[0]
else:
logger.debug('Using obstacle detector...')
# TODO: Only returns the first obstacles stream.
obstacles_streams = pylot.operator_creator.add_obstacle_detection(
center_camera_stream, time_to_decision_stream)
obstacles_stream_wo_depth = obstacles_streams[0]
if FLAGS.planning_type == 'waypoint':
logger.debug('Adding obstacle location finder...')
# Adds an operator that finds the world locations of the obstacles.
obstacles_stream = \
pylot.operator_creator.add_obstacle_location_finder(
Expand All @@ -89,11 +96,13 @@ def add_obstacle_detection(center_camera_stream,
and ground_obstacles_stream is not None
and ground_speed_limit_signs_stream is not None
and ground_stop_signs_stream is not None)
logger.debug('Using perfect obstacle detector...')
perfect_obstacles_stream = pylot.operator_creator.add_perfect_detector(
depth_camera_stream, center_camera_stream, segmented_camera_stream,
pose_stream, ground_obstacles_stream,
ground_speed_limit_signs_stream, ground_stop_signs_stream)
if FLAGS.evaluate_obstacle_detection:
logger.debug('Adding obstacle detection evaluation...')
pylot.operator_creator.add_detection_evaluation(
obstacles_stream_wo_depth,
perfect_obstacles_stream,
Expand All @@ -107,9 +116,11 @@ def add_obstacle_detection(center_camera_stream,
matching_policy='ceil',
name='timely_detection_eval_operator')
if FLAGS.perfect_obstacle_detection:
logger.debug('Using perfect obstacle detector...')
obstacles_stream = perfect_obstacles_stream

if FLAGS.simulator_obstacle_detection:
logger.debug('Using ground obstacles from the simulator...')
obstacles_stream = ground_obstacles_stream

return obstacles_stream, perfect_obstacles_stream
Expand Down Expand Up @@ -147,6 +158,7 @@ def add_traffic_light_detection(tl_transform,
"""
tl_camera_stream = None
if FLAGS.traffic_light_detection or FLAGS.perfect_traffic_light_detection:
logger.debug('Adding traffic light camera...')
# Only add the TL camera if traffic light detection is enabled.
tl_camera_setup = RGBCameraSetup('traffic_light_camera',
FLAGS.camera_image_width,
Expand All @@ -159,6 +171,7 @@ def add_traffic_light_detection(tl_transform,

traffic_lights_stream = None
if FLAGS.traffic_light_detection:
logger.debug('Using traffic light detection...')
traffic_lights_stream = \
pylot.operator_creator.add_traffic_light_detector(
tl_camera_stream)
Expand All @@ -171,6 +184,7 @@ def add_traffic_light_detection(tl_transform,
if FLAGS.perfect_traffic_light_detection:
assert (pose_stream is not None
and ground_traffic_lights_stream is not None)
logger.debug('Using perfect traffic light detection...')
# Add segmented and depth cameras with fov 45. These cameras are needed
# by the perfect traffic light detector.
tl_depth_camera_setup = DepthCameraSetup('traffic_light_depth_camera',
Expand All @@ -197,6 +211,7 @@ def add_traffic_light_detection(tl_transform,
pose_stream)

if FLAGS.simulator_traffic_light_detection:
logger.debug('Using ground traffic lights from the simulator...')
traffic_lights_stream = ground_traffic_lights_stream

return traffic_lights_stream, tl_camera_stream
Expand Down Expand Up @@ -227,12 +242,15 @@ def add_depth(transform, vehicle_id_stream, center_camera_setup,
"""
depth_stream = None
if FLAGS.depth_estimation:
logger.debug('Adding left and right cameras for depth estimation...')
(left_camera_stream,
right_camera_stream) = pylot.operator_creator.add_left_right_cameras(
transform, vehicle_id_stream)
logger.debug('Using camera depth estimation...')
depth_stream = pylot.operator_creator.add_depth_estimation(
left_camera_stream, right_camera_stream, center_camera_setup)
elif FLAGS.perfect_depth_estimation:
if FLAGS.perfect_depth_estimation:
logger.debug('Using perfect depth estimation...')
depth_stream = depth_camera_stream
return depth_stream

Expand Down Expand Up @@ -263,15 +281,22 @@ def add_lane_detection(center_camera_stream,
lane_detection_stream = None
if FLAGS.lane_detection:
if FLAGS.lane_detection_type == 'canny':
logger.debug('Using Canny Edge lane detector...')
lane_detection_stream = \
pylot.operator_creator.add_canny_edge_lane_detection(
center_camera_stream)
elif FLAGS.lane_detection_type == 'lanenet':
logger.debug('Using Lanenet lane detector...')
lane_detection_stream = \
pylot.operator_creator.add_lanenet_detection(
center_camera_stream)
elif FLAGS.perfect_lane_detection:
assert pose_stream is not None
else:
raise ValueError('Unexpected lane detection type {}'.format(
FLAGS.lane_detection_type))
if FLAGS.perfect_lane_detection:
assert pose_stream is not None, \
'Cannot added perfect lane detection without a post stream'
logger.debug('Using perfect lane detector...')
lane_detection_stream = \
pylot.operator_creator.add_perfect_lane_detector(
pose_stream, open_drive_stream, center_camera_stream)
Expand Down Expand Up @@ -321,22 +346,26 @@ def add_obstacle_tracking(center_camera_stream,
obstacles_tracking_stream = None
if FLAGS.obstacle_tracking:
if FLAGS.tracker_type == 'center_track':
logger.debug('Using CenterTrack obstacle tracker...')
obstacles_wo_history_tracking_stream = \
pylot.operator_creator.add_center_track_tracking(
center_camera_stream, center_camera_setup)
else:
logger.debug('Using obstacle tracker...')
obstacles_wo_history_tracking_stream = \
pylot.operator_creator.add_obstacle_tracking(
obstacles_stream,
center_camera_stream,
time_to_decision_stream)
logger.debug('Adding operator to compute obstacle location history...')
obstacles_tracking_stream = \
pylot.operator_creator.add_obstacle_location_history(
obstacles_wo_history_tracking_stream, depth_stream,
pose_stream, center_camera_setup)
elif FLAGS.perfect_obstacle_tracking:
if FLAGS.perfect_obstacle_tracking:
assert (pose_stream is not None
and ground_obstacles_stream is not None)
logger.debug('Using perfect obstacle tracker...')
obstacles_tracking_stream = \
pylot.operator_creator.add_perfect_tracking(
vehicle_id_stream, ground_obstacles_stream, pose_stream)
Expand All @@ -347,6 +376,7 @@ def add_obstacle_tracking(center_camera_stream,
# stream is generated by a perfect detector.
# Note: the tracker eval operator cannot compute accuracy
# if the obstacles do not contain 2D bounding boxes.
logger.debug('Adding obstacle tracking evaluation...')
pylot.operator_creator.add_tracking_evaluation(
obstacles_wo_history_tracking_stream,
ground_obstacles_stream,
Expand Down Expand Up @@ -384,13 +414,17 @@ def add_segmentation(center_camera_stream, ground_segmented_stream=None):
"""
segmented_stream = None
if FLAGS.segmentation:
logger.debug('Using semantic segmentation...')
segmented_stream = pylot.operator_creator.add_segmentation(
center_camera_stream)
if FLAGS.evaluate_segmentation:
assert ground_segmented_stream is not None
assert ground_segmented_stream is not None, \
"Cannot evaluate segmentation without ground truth"
logger.debug('Adding semantic segmentation evaluation...')
pylot.operator_creator.add_segmentation_evaluation(
ground_segmented_stream, segmented_stream)
elif FLAGS.perfect_segmentation:
if FLAGS.perfect_segmentation:
logger.debug('Using perfect semantic segmentation...')
assert ground_segmented_stream is not None
return ground_segmented_stream
return segmented_stream
Expand Down Expand Up @@ -429,9 +463,11 @@ def add_prediction(obstacles_tracking_stream,
notify_reading_stream = None
if FLAGS.prediction:
if FLAGS.prediction_type == 'linear':
logger.debug('Using linear prediction...')
prediction_stream = pylot.operator_creator.add_linear_prediction(
obstacles_tracking_stream)
elif FLAGS.prediction_type == 'r2p2':
logger.debug('Using R2P2 prediction...')
assert point_cloud_stream is not None
assert lidar_setup is not None
prediction_stream = pylot.operator_creator.add_r2p2_prediction(
Expand All @@ -441,9 +477,11 @@ def add_prediction(obstacles_tracking_stream,
FLAGS.prediction_type))
if FLAGS.evaluate_prediction:
assert pose_stream is not None
logger.debug('Adding prediction evaluation...')
pylot.operator_creator.add_prediction_evaluation(
pose_stream, obstacles_tracking_stream, prediction_stream)
if FLAGS.visualize_prediction:
logger.debug('Adding for prediction evaluation...')
# Add bird's eye camera.
top_down_transform = pylot.utils.get_top_down_transform(
camera_transform, FLAGS.top_down_camera_altitude)
Expand All @@ -454,6 +492,8 @@ def add_prediction(obstacles_tracking_stream,
notify_reading_stream) = pylot.operator_creator.add_camera_driver(
top_down_seg_camera_setup, vehicle_id_stream,
release_sensor_stream)
else:
logger.debug('Not using prediction...')
return (prediction_stream, top_down_segmented_camera_stream,
notify_reading_stream)

Expand Down Expand Up @@ -483,10 +523,11 @@ def add_planning(goal_location, pose_stream, prediction_stream,
:py:class:`erdos.ReadStream`: Stream on which the waypoints are
published.
"""
logger.debug('Using behavior planning...')
trajectory_stream = pylot.operator_creator.add_behavior_planning(
pose_stream, open_drive_stream, global_trajectory_stream,
goal_location)

logger.debug('Using planning...')
waypoints_stream = pylot.operator_creator.add_planning(
pose_stream, prediction_stream, traffic_lights_stream, lanes_stream,
trajectory_stream, open_drive_stream, time_to_decision_stream)
Expand All @@ -513,57 +554,55 @@ def add_control(pose_stream,
published.
"""
if FLAGS.control == 'pid':
logger.debug('Using PID controller...')
control_stream = pylot.operator_creator.add_pid_control(
pose_stream, waypoints_stream)
elif FLAGS.control == 'mpc':
logger.debug('Using MPC controller...')
control_stream = pylot.operator_creator.add_mpc(
pose_stream, waypoints_stream)
elif FLAGS.control in ['simulator_auto_pilot', 'manual']:
# TODO: Hack! We synchronize on a single stream, based on a
# guesestimate of which stream is slowest.
logger.debug('Using the manual control/autopilot...')
stream_to_sync_on = waypoints_stream
if (FLAGS.evaluate_obstacle_detection
and not FLAGS.perfect_obstacle_detection):
# Ensure that the perfect obstacle detector doesn't remain
# behind.
logger.debug('Synchronizing ticking using the perfect detector'
' stream')
stream_to_sync_on = perfect_obstacles_stream
else:
logger.debug('Synchronizing ticking using the waypoints stream')
control_stream = pylot.operator_creator.add_synchronizer(
ground_vehicle_id_stream, stream_to_sync_on)
else:
raise ValueError('Unexpected control {}'.format(FLAGS.control))

if FLAGS.evaluate_control:
logger.debug('Adding control evaluation operator...')
pylot.operator_creator.add_control_evaluation(pose_stream,
waypoints_stream)
return control_stream


def add_evaluation(vehicle_id_stream,
pose_stream,
imu_stream,
pose_stream_for_control=None,
waypoints_stream_for_control=None):
def add_evaluation(vehicle_id_stream, pose_stream, imu_stream):
if FLAGS.evaluation:
# Add the collision sensor.
logger.debug('Adding collision logging sensor...')
collision_stream = pylot.operator_creator.add_collision_sensor(
vehicle_id_stream)

# Add the lane invasion sensor.
logger.debug('Adding lane invasion sensor...')
lane_invasion_stream = pylot.operator_creator.add_lane_invasion_sensor(
vehicle_id_stream)

# Add the traffic light invasion sensor.
logger.debug('Adding traffic light invasion sensor...')
traffic_light_invasion_stream = \
pylot.operator_creator.add_traffic_light_invasion_sensor(
vehicle_id_stream, pose_stream)

# Add the evaluation logger.
logger.debug('Adding overall evaluation operator...')
pylot.operator_creator.add_eval_metric_logging(
collision_stream, lane_invasion_stream,
traffic_light_invasion_stream, imu_stream, pose_stream)

# Add control evaluation logging operator.
if (FLAGS.evaluate_control and pose_stream_for_control
and waypoints_stream_for_control):
pylot.operator_creator.add_control_evaluation(
pose_stream_for_control, waypoints_stream_for_control)
Loading

0 comments on commit 64d0d9d

Please sign in to comment.