Skip to content

Commit

Permalink
improve the detection post-processing speed in order of magnitude (40…
Browse files Browse the repository at this point in the history
…ms->121.519µs)
  • Loading branch information
Michal Conos committed Sep 13, 2024
1 parent b65fa7d commit 485fac6
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 44 deletions.
6 changes: 6 additions & 0 deletions env
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
export LIBTORCH=$(pwd)/libtorch/
export LIBTORCH_INCLUDE=$(pwd)/libtorch/
export LIBTORCH_LIB=$(pwd)/libtorch/

export LD_LIBRARY_PATH="$LIBTORCH/lib/:$LD_LIBRARY_PATH"

122 changes: 78 additions & 44 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::time::Instant;

use tch::{Device, IValue, Kind, Tensor};

use crate::{image::ImageCHW, BBox, SegBBox, SegmentationResult, YOLOModel};
Expand Down Expand Up @@ -197,6 +199,8 @@ impl DetectionTools {
conf_thresh: f64,
iou_thresh: f64,
) -> Vec<BBox> {
let mut timings = Vec::new();
let start = Instant::now();
let prediction = prediction.get(0);
let prediction = prediction.transpose(1, 0);
let (anchors, classes_no) = prediction.size2().unwrap();
Expand All @@ -209,62 +213,88 @@ impl DetectionTools {
let nclasses = (classes_no - 4) as usize;
// println!("classes_no={classes_no}, anchors={anchors}");

// println!("pred={:?}", prediction);
let sliced_predictions = prediction.slice(1, 4, 84, 1); // 1 for the dimension index
// println!("sliced_predictions={:?}", sliced_predictions);

// Compute the maximum along dimension 1 (across the specified columns)
let max_values = sliced_predictions.amax(1, false);
// println!("max_values={:?}", max_values);
// println!("max_values={}", max_values);

// Create a boolean mask where max values are greater than the confidence threshold
let xc = max_values.gt(conf_thresh);
let t = xc.nonzero().view(-1);
// println!("t={:?} t={t}", t);
let indexes = Vec::<i64>::try_from(t).expect("can't get indexes where confidence is met");
// println!("indexes={:?}", indexes);
// println!("xc={:?}", xc);

let mut bboxes: Vec<Vec<BBox>> = (0..nclasses).map(|_| vec![]).collect();

for index in 0..anchors {
timings.push(("prolog", start.elapsed()));
let start = Instant::now();

for index in indexes {
// println!("has confidence: {index}");
let pred = Vec::<f64>::try_from(prediction.get(index)).expect("wrong type of tensor");

// println!("index={index}, pred={}", pred.len());

let mut max_conf = 0.0;
let mut idx = 0;
for i in 4..classes_no as usize {
let confidence = pred[i];
if confidence > conf_thresh {
let class_index = i - 4;
// println!(
// "confidence={confidence}, class_index={class_index} class_name={}",
// CLASSES[class_index]
// );

let (_, orig_h, orig_w) = image_dim;
let (_, sh, sw) = scaled_image_dim;
let cx = sw as f64 / 2.0;
let cy = sh as f64 / 2.0;
let mut dx = pred[0] - cx;
let mut dy = pred[1] - cy;
let mut w = pred[2];
let mut h = pred[3];

let aspect = orig_w as f64 / orig_h as f64;

if orig_w > orig_h {
dy *= aspect;
h *= aspect;
} else {
dx /= aspect;
w /= aspect;
}

let x = cx + dx;
let y = cy + dy;

let xmin = ((x - w / 2.) * w_ratio).clamp(0.0, initial_w - 1.0);
let ymin = ((y - h / 2.) * h_ratio).clamp(0.0, initial_h - 1.0);
let xmax = ((x + w / 2.) * w_ratio).clamp(0.0, initial_w - 1.0);
let ymax = ((y + h / 2.) * h_ratio).clamp(0.0, initial_h - 1.0);
if confidence > max_conf {
max_conf = confidence;
idx = i;
}
}

let bbox = BBox {
xmin,
ymin,
xmax,
ymax,
conf: confidence,
cls: class_index,
name: crate::classes::DETECT_CLASSES[class_index],
};
bboxes[class_index].push(bbox)
if max_conf > conf_thresh && idx >= 4 {
let class_index = idx - 4;

let (_, orig_h, orig_w) = image_dim;
let (_, sh, sw) = scaled_image_dim;
let cx = sw as f64 / 2.0;
let cy = sh as f64 / 2.0;
let mut dx = pred[0] - cx;
let mut dy = pred[1] - cy;
let mut w = pred[2];
let mut h = pred[3];

let aspect = orig_w as f64 / orig_h as f64;

if orig_w > orig_h {
dy *= aspect;
h *= aspect;
} else {
dx /= aspect;
w /= aspect;
}

let x = cx + dx;
let y = cy + dy;

let xmin = ((x - w / 2.) * w_ratio).clamp(0.0, initial_w - 1.0);
let ymin = ((y - h / 2.) * h_ratio).clamp(0.0, initial_h - 1.0);
let xmax = ((x + w / 2.) * w_ratio).clamp(0.0, initial_w - 1.0);
let ymax = ((y + h / 2.) * h_ratio).clamp(0.0, initial_h - 1.0);

let bbox = BBox {
xmin,
ymin,
xmax,
ymax,
conf: max_conf,
cls: class_index,
name: crate::classes::DETECT_CLASSES[class_index],
};
bboxes[class_index].push(bbox)
}
}
timings.push(("bbox-conf", start.elapsed()));
let start = Instant::now();

for bboxes_for_class in bboxes.iter_mut() {
bboxes_for_class.sort_by(|b1, b2| b2.conf.partial_cmp(&b1.conf).unwrap());
Expand All @@ -287,14 +317,18 @@ impl DetectionTools {
}
bboxes_for_class.truncate(current_index);
}
timings.push(("iou-processing", start.elapsed()));

let start = Instant::now();
let mut result = vec![];

for bboxes_for_class in bboxes.iter() {
for bbox in bboxes_for_class.iter() {
result.push(*bbox);
}
}
timings.push(("epilog", start.elapsed()));
// println!("timings={:?}", timings);

return result;
}
Expand Down

0 comments on commit 485fac6

Please sign in to comment.