Skip to content

Commit

Permalink
update ocsort with byte
Browse files Browse the repository at this point in the history
  • Loading branch information
HanGuangXin committed Apr 27, 2022
1 parent 05bed58 commit bdfbc98
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 6 deletions.
4 changes: 2 additions & 2 deletions tools/demo_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def image_demo(predictor, vis_folder, current_time, args):
else:
files = [args.path]
files.sort()
tracker = OCSort(det_thresh=args.track_thresh, iou_threshold=args.iou_thresh)
tracker = OCSort(det_thresh=args.track_thresh, iou_threshold=args.iou_thresh, use_byte=args.use_byte)
timer = Timer()
results = []

Expand Down Expand Up @@ -161,7 +161,7 @@ def imageflow_demo(predictor, vis_folder, current_time, args):
vid_writer = cv2.VideoWriter(
save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
)
tracker = OCSort(det_thresh=args.track_thresh, iou_threshold=args.iou_thresh)
tracker = OCSort(det_thresh=args.track_thresh, iou_threshold=args.iou_thresh, use_byte=args.use_byte)
timer = Timer()
frame_id = 0
results = []
Expand Down
29 changes: 27 additions & 2 deletions trackers/ocsort_tracker/ocsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def get_state(self):

class OCSort(object):
def __init__(self, det_thresh, max_age=30, min_hits=3,
iou_threshold=0.3, delta_t=3, asso_func="iou", inertia=0.2):
iou_threshold=0.3, delta_t=3, asso_func="iou", inertia=0.2, use_byte=False):
"""
Sets key parameters for SORT
"""
Expand All @@ -187,6 +187,7 @@ def __init__(self, det_thresh, max_age=30, min_hits=3,
self.delta_t = delta_t
self.asso_func = ASSO_FUNCS[asso_func]
self.inertia = inertia
self.use_byte = use_byte
KalmanBoxTracker.count = 0

def update(self, output_results, img_info, img_size):
Expand All @@ -213,7 +214,10 @@ def update(self, output_results, img_info, img_size):
scale = min(img_size[0] / float(img_h), img_size[1] / float(img_w))
bboxes /= scale
dets = np.concatenate((bboxes, np.expand_dims(scores, axis=-1)), axis=1)

inds_low = scores > 0.1
inds_high = scores < self.det_thresh
inds_second = np.logical_and(inds_low, inds_high) # self.det_thresh > score > 0.1, for second matching
dets_second = dets[inds_second] # detections for second matching
remain_inds = scores > self.det_thresh
dets = dets[remain_inds]

Expand Down Expand Up @@ -247,6 +251,27 @@ def update(self, output_results, img_info, img_size):
"""
Second round of associaton by OCR
"""
# BYTE association
if self.use_byte and len(dets_second) > 0 and unmatched_trks.shape[0] > 0:
u_trks = trks[unmatched_trks]
iou_left = self.asso_func(dets_second, u_trks) # iou between low score detections and unmatched tracks
iou_left = np.array(iou_left)
if iou_left.max() > self.iou_threshold:
"""
NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
get a higher performance especially on MOT17/MOT20 datasets. But we keep it
uniform here for simplicity
"""
matched_indices = linear_assignment(-iou_left)
to_remove_trk_indices = []
for m in matched_indices:
det_ind, trk_ind = m[0], unmatched_trks[m[1]]
if iou_left[m[0], m[1]] < self.iou_threshold:
continue
self.trackers[trk_ind].update(dets_second[det_ind, :])
to_remove_trk_indices.append(trk_ind)
unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices))

if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
left_dets = dets[unmatched_dets]
left_trks = last_boxes[unmatched_trks]
Expand Down
1 change: 1 addition & 0 deletions utils/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def make_parser():
parser.add_argument("--mot20", dest="mot20", default=False, action="store_true", help="test mot20.")
parser.add_argument("--public", action="store_true", help="use public detection")
parser.add_argument('--asso', default="iou", help="similarity function: iou/giou/diou/ciou/ctdis")
parser.add_argument("--use_byte", dest="use_byte", default=False, action="store_true", help="use byte in tracking.")

# for kitti/bdd100k inference with public detections
parser.add_argument('--raw_results_path', type=str, default="exps/permatrack_kitti_test/",
Expand Down
5 changes: 3 additions & 2 deletions yolox/evaluators/mot_evaluator_dance.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def evaluate_ocsort(
model = model_trt

tracker = OCSort(det_thresh = self.args.track_thresh, iou_threshold=self.args.iou_thresh,
asso_func=self.args.asso, delta_t=self.args.deltat, inertia=self.args.inertia)
asso_func=self.args.asso, delta_t=self.args.deltat, inertia=self.args.inertia, use_byte=self.args.use_byte)

detections = dict()

Expand All @@ -240,7 +240,7 @@ def evaluate_ocsort(

if frame_id == 1:
tracker = OCSort(det_thresh = self.args.track_thresh, iou_threshold=self.args.iou_thresh,
asso_func=self.args.asso, delta_t=self.args.deltat, inertia=self.args.inertia)
asso_func=self.args.asso, delta_t=self.args.deltat, inertia=self.args.inertia, use_byte=self.args.use_byte)
if len(results) != 0:
result_filename = os.path.join(result_folder, '{}.txt'.format(video_names[video_id - 1]))
write_results_no_score(result_filename, results)
Expand Down Expand Up @@ -352,6 +352,7 @@ def convert_to_coco_format(self, outputs, info_imgs, ids):
return data_list



def evaluate_prediction(self, data_dict, statistics):
if not is_main_process():
return 0, 0, None
Expand Down

0 comments on commit bdfbc98

Please sign in to comment.