Skip to content

Commit

Permalink
Added video, argparse
Browse files Browse the repository at this point in the history
  • Loading branch information
skhadem committed Mar 12, 2019
1 parent ff4f930 commit f28c0ae
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 28 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
Kitti/testing
Kitti/training
legacy
eval/video/unused
22 changes: 19 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ If interested, join the slack workspace where the paper is discussed, issues are
## Introduction
PyTorch implementation for this [paper](https://arxiv.org/abs/1612.00496).

![example](http://soroushkhadem.com/img/2d-top-3d-bottom1.png)
![example-image](http://soroushkhadem.com/img/2d-top-3d-bottom1.png)

At the moment, it takes approximately 0.4s per frame, depending on the number of objects
detected. An improvement will be speed upgrades soon. Here is the current fastest possible:
![example-video](http://soroushkhadem.com/img/3d-bbox-video1.mp4)

## Requirements
- PyTorch
Expand All @@ -19,10 +23,22 @@ cd weights/
This will download the weights I have trained and also the YOLOv3 weights from the
official yolo [site](https://pjreddie.com/darknet/yolo/).

To run in evaluation:
To see options:
```
python Run.py --help
```
python Run.py

Run through all images in default directory (eval/image_2/):
```
python Run.py [--show-yolo]
```

Run through default video:
```
python Run.py --video [--hide-debug]
```


>Note: This script expects images in `./Kitti/testing/image_2/` and corresponding projection matricies
in `./Kitti/testing/calib/`. See [training](#training) for where to download data from.

Expand Down
103 changes: 81 additions & 22 deletions Run.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,57 @@
from torch.autograd import Variable
from torchvision.models import vgg

import argparse

def plot_regressed_3d_bbox(img, truth_img, cam_to_img, box_2d, dimensions, alpha, theta_ray):
def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')


parser = argparse.ArgumentParser()

parser.add_argument("--image-dir", default="eval/image_2/",
help="Relative path to the directory containing images to detect. Default \
is eval/image_2/")

# TODO: support multiple cal matrix input types
parser.add_argument("--cal-dir", default="camera_cal/",
help="Relative path to the directory containing camera calibration form KITTI. \
Default is camera_cal/")

parser.add_argument("--video", action="store_true",
help="Weather or not to advance frame-by-frame as fast as possible. \
By default, this will pull images from ./eval/video")

parser.add_argument("--show-yolo", action="store_true",
help="Show the 2D BoundingBox detecions on a separate image")

parser.add_argument("--hide-debug", action="store_true",
help="Supress the printing of each 3d location")


def plot_regressed_3d_bbox(img, cam_to_img, box_2d, dimensions, alpha, theta_ray, img_2d=None):

# the math! returns X, the corners used for constraint
location, X = calc_location(dimensions, cam_to_img, box_2d, alpha, theta_ray)

orient = alpha + theta_ray

plot_2d_box(truth_img, box_2d)
if img_2d is not None:
plot_2d_box(img_2d, box_2d)

plot_3d_box(img, cam_to_img, orient, dimensions, location) # 3d boxes

return location

def main():

FLAGS = parser.parse_args()

# load torch
weights_path = os.path.abspath(os.path.dirname(__file__)) + '/weights'
model_lst = [x for x in sorted(os.listdir(weights_path)) if x.endswith('.pkl')]
Expand All @@ -60,19 +96,30 @@ def main():

averages = ClassAverages.ClassAverages()

# TODO: clean up how this is done
# TODO: clean up how this is done. flag?
angle_bins = generate_bins(2)

img_path = os.path.abspath(os.path.dirname(__file__)) + '/Kitti/testing/image_2/'
image_dir = FLAGS.image_dir
cal_dir = FLAGS.cal_dir
if FLAGS.video:
if FLAGS.image_dir == "eval/image_2/" and FLAGS.cal_dir == "camera_cal/":
image_dir = "eval/video/2011_09_26/image_2/"
cal_dir = "eval/video/2011_09_26/"

# using P from each frame
calib_path = os.path.abspath(os.path.dirname(__file__)) + '/Kitti/testing/calib/'

img_path = os.path.abspath(os.path.dirname(__file__)) + "/" + image_dir
# using P_rect from global calibration file
# calib_path = os.path.abspath(os.path.dirname(__file__)) + '/camera_cal/'
# calib_file = calib_path + "calib_cam_to_cam.txt"
calib_path = os.path.abspath(os.path.dirname(__file__)) + "/" + cal_dir
calib_file = calib_path + "calib_cam_to_cam.txt"

ids = [x.split('.')[0] for x in sorted(os.listdir(img_path))]
# using P from each frame
# calib_path = os.path.abspath(os.path.dirname(__file__)) + '/Kitti/testing/calib/'

try:
ids = [x.split('.')[0] for x in sorted(os.listdir(img_path))]
except:
print("\nError: no images in %s"%img_path)
exit()

for id in ids:

Expand All @@ -81,7 +128,7 @@ def main():
img_file = img_path + id + ".png"

# P for each frame
calib_file = calib_path + id + ".txt"
# calib_file = calib_path + id + ".txt"

truth_img = cv2.imread(img_file)
img = np.copy(truth_img)
Expand Down Expand Up @@ -125,18 +172,30 @@ def main():
alpha += angle_bins[argmax]
alpha -= np.pi

location = plot_regressed_3d_bbox(img, truth_img, proj_matrix, box_2d, dim, alpha, theta_ray)

print('Estimated pose: %s'%location)

numpy_vertical = np.concatenate((truth_img, img), axis=0)
cv2.imshow('SPACE for next image, any other key to exit', numpy_vertical)

print("\n")
print('Got %s poses in %.3f seconds'%(len(detections), time.time() - start_time))
print('-------------')
if cv2.waitKey(0) != 32: # space bar
exit()
if FLAGS.show_yolo:
location = plot_regressed_3d_bbox(img, proj_matrix, box_2d, dim, alpha, theta_ray, truth_img)
else:
location = plot_regressed_3d_bbox(img, proj_matrix, box_2d, dim, alpha, theta_ray)

if not FLAGS.hide_debug:
print('Estimated pose: %s'%location)

if FLAGS.show_yolo:
numpy_vertical = np.concatenate((truth_img, img), axis=0)
cv2.imshow('SPACE for next image, any other key to exit', numpy_vertical)
else:
cv2.imshow('3D detections', img)

if not FLAGS.hide_debug:
print("\n")
print('Got %s poses in %.3f seconds'%(len(detections), time.time() - start_time))
print('-------------')

if FLAGS.video:
cv2.waitKey(1)
else:
if cv2.waitKey(0) != 32: # space bar
exit()

if __name__ == '__main__':
main()
3 changes: 2 additions & 1 deletion library/File.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def get_P(cab_f):
return_matrix = cam_P.reshape((3,4))
return return_matrix

file_not_found(cab_f)
# try other type of file
return get_calibration_cam_to_image

def get_R0(cab_f):
for line in open(cab_f):
Expand Down
4 changes: 2 additions & 2 deletions torch_lib/Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,8 @@ class DetectedObject:
def __init__(self, img, detection_class, box_2d, proj_matrix, label=None):

if isinstance(proj_matrix, str): # filename
# proj_matrix = get_P(proj_matrix)
proj_matrix = get_calibration_cam_to_image(proj_matrix)
proj_matrix = get_P(proj_matrix)
# proj_matrix = get_calibration_cam_to_image(proj_matrix)

self.proj_matrix = proj_matrix
self.theta_ray = self.calc_theta_ray(img, box_2d, proj_matrix)
Expand Down

0 comments on commit f28c0ae

Please sign in to comment.