
Introducing multi object tracking - adding Joint Detection and Embedding Tracker (JDETracker) #92

Open · wants to merge 11 commits into main
100 changes: 100 additions & 0 deletions projects/multi_object_tracking/README.md
@@ -0,0 +1,100 @@
# Multi Object Tracking for PyTorchVideo

This project demonstrates multi-object tracking for PyTorchVideo.
Currently the project contains JDETracker as the multi-object tracker; more trackers can
be added to this project in the future.


## JDETracker
**Joint Detection and Embedding (JDE) Tracker** was introduced in the paper 'Towards Real-Time Multi-Object Tracking' (https://arxiv.org/abs/1909.12605).
This tracker can work with any model that provides the following two inputs:

(i) pred_dets: Detection results for the image, i.e. x1, y1, x2, y2, object_conf. (The detections should be passed through NMS or a similar technique, with the coordinates then scaled back to the original image size - keeping this outside the tracker allows more flexibility in the choice of detector.)

(ii) pred_embs: Embedding results for the image, one appearance embedding per detection.
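
As a minimal sketch of these two inputs (the tensor shapes and the embedding size of 512 are assumptions inferred from the demo script, not a fixed API):

```python
import torch

from mot.tracker import JDETracker

# pred_dets: (N, 5) with x1, y1, x2, y2, object_conf per detection,
# already scaled to original-image coordinates
pred_dets = torch.tensor([[100., 120., 180., 360., 0.9],
                          [400., 150., 470., 380., 0.8]])
# pred_embs: (N, E) with one E-dim appearance embedding per detection
pred_embs = torch.rand(2, 512)  # E = 512 is an assumption

tracker = JDETracker()
online_targets = tracker.update(pred_dets, pred_embs)
```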

#### Folder structure

```
.
├── mot                       # The main multi-object tracking library, which can later be added to PyTorchVideo
│   ├── matching
│   │   ├── ..
│   │   └── jde_matching.py
│   ├── motion
│   │   ├── ..
│   │   └── kalman_filter.py  # Kalman filter used for motion prediction
│   └── tracker
│       ├── ..
│       ├── base_jde.py       # TrackState and STrack classes, the core base classes for JDE tracking
│       └── jde_tracker.py    # The JDETracker class
├── demo_jde_tracker.py       # Sample demo integrating a detector with the JDE tracker to obtain tracking results
├── detector_utils.py         # Utilities (including the model definition) for the detector used in the paper
├── tests                     # Test files for the JDETracker
├── jde_dets.pt               # Detections and embeddings for running the test file
└── README.md
```

#### Set up

To set up your system for testing the JDETracker integrated with a detector, follow these steps:

(1) Set the current working directory to the multi_object_tracking folder:
```bash
cd pytorchvideo/projects/multi_object_tracking
```

(2) Download the weights:
Download the weights from https://drive.google.com/open?id=1nlnuYfGNuHWZztQHXwVZSL_FvfE551pA and copy them into the `weights` folder:
```bash
mkdir weights
cd weights
cp path/to/weights .  # copy the weights file to this folder
cd ..
```

(3) Download a sample video:
This demo uses the MOT16-03.mp4 video, which is available at https://drive.google.com/file/d/1254q3ruzBzgn4LUejDVsCtT05SIEieQg/view?usp=sharing.
Store this video at videos/MOT16-03.mp4.
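For example (assuming the file was downloaded to path/to/MOT16-03.mp4):
```bash
mkdir -p videos
cp path/to/MOT16-03.mp4 videos/  # copy the downloaded video here
```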

(4) Set up the requirements (an example install command follows the list):
* [PyTorch](https://pytorch.org) >= 1.2.0
* opencv-python
* cython-bbox (`pip install cython_bbox`)
* ffmpeg (used for creating the tracking result video)
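
One possible way to install the Python packages (pip package names are assumptions for your environment; ffmpeg comes from your system package manager):
```bash
pip install torch opencv-python cython_bbox
```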

#### Running the demo

```bash
python demo_jde_tracker.py \
    --input-video videos/MOT16-03.mp4 \
    --weights weights/jde.1088x608.uncertainty.pt \
    --cfg yolov3_1088x608.cfg \
    --output-root .
```

* The results will be stored as follows:

(i) Individual frames with their bounding boxes and track ids will be stored in the _frame_ folder.

(ii) After the entire video is processed, a _result.mp4_ video of the entire run will be created.
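
Concretely, with `--output-root .` the output layout looks like this (frame numbering follows the demo script):

```
.
├── frame/
│   ├── 00000.jpg   # frame 0 with bounding boxes and track ids drawn
│   ├── 00001.jpg
│   └── ...
└── result.mp4      # video assembled from the frames by ffmpeg
```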

#### Running the JDE Tracker with any other detector
As stated above, the JDE tracker can be integrated with any model that outputs both detections and embeddings.
Also refer to the test file in the tests folder for further reference.
```python

# Step 1: Create an instance of JDETracker
from mot.tracker import JDETracker
tracker = JDETracker()

# Step 2: (In a loop)
# Run each image/video frame through the detector to obtain detections and embeddings,
# filter them through NMS and scale them back to the original image size,
# then pass them to the update function of the JDETracker.
# .....
online_targets = tracker.update(pred_dets, pred_embs)
# ....
```
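
A slightly fuller sketch of that loop, where `video_frames`, `my_detector`, `my_nms`, and `scale_to_original` are hypothetical stand-ins for your own frame source, detector, and post-processing (only `JDETracker`, `update`, `track_id`, and `tlwh` come from this project):

```python
from mot.tracker import JDETracker

tracker = JDETracker()
for frame in video_frames:                        # hypothetical iterable of frames
    dets, embs = my_detector(frame)               # hypothetical: (N, 5) boxes + conf, (N, E) embeddings
    dets = my_nms(dets)                           # hypothetical: suppress overlapping boxes
    dets[:, :4] = scale_to_original(dets[:, :4])  # hypothetical: back to original image coordinates
    online_targets = tracker.update(dets, embs)
    for t in online_targets:
        print(t.track_id, t.tlwh)                 # track id and (top-left x, y, width, height) box
```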

#### References
[1] Zhongdao Wang, Liang Zheng, Yixuan Liu, Yali Li, Shengjin Wang. Towards Real-Time Multi-Object Tracking. ECCV 2020.

[2] https://github.com/Zhongdao/Towards-Realtime-MOT
96 changes: 96 additions & 0 deletions projects/multi_object_tracking/demo_jde_tracker.py
@@ -0,0 +1,96 @@
import argparse
import os

import cv2
import torch

from detector_utils import *

from mot.tracker import JDETracker

class DemoJDETracker:
    def __init__(self, opt):
        self.opt = opt
        cfg_dict = parse_model_cfg(opt.cfg)
        self.opt.img_size = [int(cfg_dict[0]['width']), int(cfg_dict[0]['height'])]

        self.result_root = self.opt.output_root if opt.output_root != '' else '.'
        os.makedirs(self.result_root, exist_ok=True)
        self.frame_dir = os.path.join(self.result_root, 'frame')
        os.makedirs(self.frame_dir, exist_ok=True)

        # set dataloader
        self.dataloader = LoadVideo(opt.input_video, opt.img_size)

        # load detector model (nID is the number of identities used to train
        # the embedding head; 14455 follows the reference JDE implementation)
        self.model = Darknet(opt.cfg, nID=14455)
        self.model.load_state_dict(torch.load(opt.weights, map_location='cpu')['model'], strict=False)

        if torch.cuda.is_available():
            self.model.cuda().eval()
        else:
            self.model.eval()

        print("Model load complete")

        # initialise JDE Tracker
        self.tracker = JDETracker()

    def track_video(self):
        results = []
        frame_id = 0

        for path, img, img0 in self.dataloader:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            im_blob = torch.from_numpy(img).to(device).unsqueeze(0)

            with torch.no_grad():
                pred_model = self.model(im_blob)

            if len(pred_model) > 0:
                # perform additional steps before pushing into the tracker,
                # such as threshold filtering, NMS and scaling coordinates
                pred_model = pred_model[pred_model[:, :, 4] > self.opt.conf_thres]
                pred_model = non_max_suppression(pred_model.unsqueeze(0), self.opt.conf_thres, self.opt.nms_thres)[0].cpu()
                scale_coords(self.opt.img_size, pred_model[:, :4], img0.shape).round()

                # split pred_model into two parts:
                # pred_dets gives the detection results of the image, i.e. (x1, y1, x2, y2, object_conf)
                # pred_embs gives the embedding results of the image
                pred_dets, pred_embs = pred_model[:, :5], pred_model[:, 6:]
                online_targets = self.tracker.update(pred_dets, pred_embs)

                # saving the results for display
                online_tlwhs = []
                online_ids = []
                for t in online_targets:
                    tlwh = t.tlwh
                    tid = t.track_id
                    # discard boxes much wider than tall (unlikely to be pedestrians)
                    vertical = tlwh[2] / tlwh[3] > 1.6
                    if tlwh[2] * tlwh[3] > self.opt.min_box_area and not vertical:
                        online_tlwhs.append(tlwh)
                        online_ids.append(tid)
                # save results
                results.append((frame_id + 1, online_tlwhs, online_ids))
                online_im = plot_tracking(img0, online_tlwhs, online_ids, frame_id=frame_id)
                cv2.imwrite(os.path.join(self.frame_dir, '{:05d}.jpg'.format(frame_id)), online_im)

            frame_id += 1

        # save results as video after the loop
        output_video_path = os.path.join(self.result_root, 'result.mp4')
        cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -c:v copy {}'.format(self.frame_dir, output_video_path)
        os.system(cmd_str)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(prog='demo_jde_tracker.py')
    parser.add_argument('--cfg', type=str, default='cfg/yolov3_1088x608.cfg', help='cfg file path')
    parser.add_argument('--weights', type=str, default='weights/latest.pt', help='path to weights file')
    parser.add_argument('--conf-thres', type=float, default=0.5, help='object confidence threshold')
    parser.add_argument('--nms-thres', type=float, default=0.4, help='iou threshold for non-maximum suppression')
    parser.add_argument('--min-box-area', type=float, default=200, help='filter out tiny boxes')
    parser.add_argument('--input-video', type=str, help='path to the input video')
    parser.add_argument('--output-format', type=str, default='video', choices=['video', 'text'], help='expected output format, video or text')
    parser.add_argument('--output-root', type=str, default='results', help='expected output root path')
    opt = parser.parse_args()
    print(opt, end='\n\n')

    jde_obj = DemoJDETracker(opt)
    jde_obj.track_video()