Skip to content

Commit

Permalink
find the right place where to send the tensor to cpu to gain addition…
Browse files Browse the repository at this point in the history
…al speedup on CUDA/CPU transfer
  • Loading branch information
Michal Conos committed Sep 13, 2024
1 parent 485fac6 commit b5628fe
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 27 deletions.
2 changes: 1 addition & 1 deletion benches/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ fn bench_detection_e2e(c: &mut Criterion) {
let image = Image::new(black_box("images/bus.jpg"), black_box((640, 640)));
let yolo = YoloV8ObjectDetection::new();
let result = yolo.predict(black_box(&image), black_box(0.25), black_box(0.7));
black_box(result.postprocess())
black_box(result.postprocess().0)
})
});
}
Expand Down
29 changes: 25 additions & 4 deletions examples/predict/main.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,43 @@
use std::time::Instant;

use tch::{TchError, Tensor};
use yolo_v8::{image::Image, YoloV8Classifier, YoloV8ObjectDetection, YoloV8Segmentation};

fn object_detection(path: &str) {
// Load image to perform object detection, note that YOLOv8 resolution must match
// scaling width and height here
let mut timings = vec![];
let start = Instant::now();
let mut image = Image::new(path, YoloV8ObjectDetection::input_dimension());
timings.push(("load image", start.elapsed()));

let start = Instant::now();
// Load exported torchscript for object detection
let yolo = YoloV8ObjectDetection::new();
let yolo = YoloV8ObjectDetection::new().post_process_on_cpu();
timings.push(("load model", start.elapsed()));

let start = Instant::now();
// Predict with non-max-suppression in the end
let bboxes = yolo.predict(&image, 0.25, 0.7).postprocess();
let prediction = yolo.predict(&image, 0.25, 0.7);
timings.push(("prediction", start.elapsed()));

let start = Instant::now();
// extract bboxes from prediction in post-processing
let bboxes = prediction.postprocess();
timings.push(("post-process", start.elapsed()));
println!("bboxes={:?}", bboxes);

let start = Instant::now();
// Draw rectangles around detected objects
image.draw_rectangle(&bboxes);
image.draw_rectangle(&bboxes.0);
timings.push(("draw rectangles", start.elapsed()));

let start = Instant::now();
// Finally save the result
image.save("images/result2.jpg");
timings.push(("save result", start.elapsed()));

println!("timings:{:?}", timings);
}

fn image_classification(path: &str) {
Expand Down Expand Up @@ -58,7 +79,7 @@ fn image_segmentation(path: &str) {

// YOLOv8n for object detection in image
fn main() -> Result<(), TchError> {
object_detection("images/bus.jpg");
object_detection("images/frame.png");
image_classification("images/bus.jpg");
image_segmentation("images/test.jpg");
Ok(())
Expand Down
27 changes: 19 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
pub mod image;
pub mod utils;

use std::time::Duration;

use image::{Image, ImageCHW};
use tch::{IValue, Tensor};
use utils::{get_model, DetectionTools, SegmentationTools};
Expand Down Expand Up @@ -72,16 +74,18 @@ pub struct ObjectDetectionPrediction {
scaled_image_dim: ImageCHW,
conf_threshold: f64,
iou_threshold: f64,
run_on_cpu: bool,
}

impl ObjectDetectionPrediction {
pub fn postprocess(&self) -> Vec<BBox> {
pub fn postprocess(&self) -> (Vec<BBox>, Vec<(&str, Duration)>) {
DetectionTools::non_max_suppression(
self.image_dim,
self.scaled_image_dim,
&self.pred,
self.conf_threshold,
self.iou_threshold,
self.run_on_cpu,
)
}
}
Expand Down Expand Up @@ -162,23 +166,31 @@ impl YoloV8Classifier {

/// Object-detection front end: pairs a loaded `YOLOv8` model with a flag that
/// selects where post-processing of its predictions should run.
pub struct YoloV8ObjectDetection {
/// The wrapped YOLOv8 model.
yolo: YOLOv8,
/// Whether predictions should be post-processed on the CPU. The constructors
/// set this to `false`; `post_process_on_cpu()` switches it to `true`. The
/// value is forwarded as `run_on_cpu` into the prediction value built by this
/// type, and `non_max_suppression` moves the prediction tensor to
/// `tch::Device::Cpu` when that flag is set.
post_process_on_cpu: bool,
}

impl YoloV8ObjectDetection {
/// Creates an object detector for the requested YOLOv8 model variant.
///
/// The model is resolved through `get_model` with the
/// `YOLOSpec::ObjectDetection` spec and loaded via `YOLOv8::new`.
/// CPU post-processing starts out disabled.
///
/// # Panics
///
/// Panics with `"can't load model"` when `YOLOv8::new` fails.
pub fn with_model(model_type: YOLOModel) -> Self {
    let model = get_model(model_type, utils::YOLOSpec::ObjectDetection);
    let yolo = YOLOv8::new(&model).expect("can't load model");

    Self {
        yolo,
        post_process_on_cpu: false,
    }
}

/// Builder-style switch: consumes the detector and hands it back with the
/// `post_process_on_cpu` flag set to `true`.
///
/// Intended to be chained right after construction, e.g.
/// `YoloV8ObjectDetection::new().post_process_on_cpu()`.
pub fn post_process_on_cpu(self) -> Self {
    let mut detector = self;
    detector.post_process_on_cpu = true;
    detector
}

pub fn new() -> Self {
Self {
yolo: YOLOv8::new(&get_model(
YOLOModel::Nano,
utils::YOLOSpec::ObjectDetection,
))
.expect("can't load model"),
post_process_on_cpu: false,
}
}

Expand All @@ -200,6 +212,7 @@ impl YoloV8ObjectDetection {
pred,
conf_threshold,
iou_threshold,
run_on_cpu: self.post_process_on_cpu,
}
}
}
Expand Down Expand Up @@ -282,12 +295,10 @@ impl YOLOv8 {
.to_device(self.device)
.g_div_scalar(255.);

let pred = self
.model
.forward_ts(&[img])
.unwrap()
.to_device(tch::Device::Cpu);
// .to_device(self.device);
let pred = self.model.forward_ts(&[img]).unwrap();

// .to_device(tch::Device::Cpu)
// .to_device(self.device)

pred
}
Expand All @@ -309,7 +320,7 @@ mod test {
fn test_detection() {
let image = Image::new("images/bus.jpg", YoloV8ObjectDetection::input_dimension());
let yolo = YoloV8ObjectDetection::new();
let detection = yolo.predict(&image, 0.25, 0.7).postprocess();
let detection = yolo.predict(&image, 0.25, 0.7).postprocess().0;
println!("detection={:?}", detection);
assert_eq!(3, detection.len());
bbox_eq(
Expand Down
38 changes: 24 additions & 14 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::time::Instant;
use std::time::{Duration, Instant};

use tch::{Device, IValue, Kind, Tensor};

Expand Down Expand Up @@ -198,41 +198,52 @@ impl DetectionTools {
prediction: &tch::Tensor,
conf_thresh: f64,
iou_thresh: f64,
) -> Vec<BBox> {
run_on_cpu: bool, // in certain cases it's faster to post-process with cpu
) -> (Vec<BBox>, Vec<(&str, Duration)>) {
let mut timings = Vec::new();
let start = Instant::now();
let prediction = prediction.get(0);
let prediction = prediction.transpose(1, 0);
let (anchors, classes_no) = prediction.size2().unwrap();
let prediction = if run_on_cpu {
prediction.transpose(1, 0).to_device(tch::Device::Cpu)
} else {
prediction.transpose(1, 0)
};
let (_anchors, classes_no) = prediction.size2().unwrap();

let initial_w = image_dim.2 as f64;
let initial_h = image_dim.1 as f64;
let w_ratio = initial_w / scaled_image_dim.2 as f64;
let h_ratio = initial_h / scaled_image_dim.1 as f64;

let nclasses = (classes_no - 4) as usize;
// println!("classes_no={classes_no}, anchors={anchors}");
let sliced_predictions = prediction.slice(1, 4, 84, 1);

// println!("pred={:?}", prediction);
let sliced_predictions = prediction.slice(1, 4, 84, 1); // 1 for the dimension index
// println!("sliced_predictions={:?}", sliced_predictions);
timings.push(("prolog", start.elapsed()));
let start = Instant::now();

// Compute the maximum along dimension 1 (across the specified columns)
let max_values = sliced_predictions.amax(1, false);
// println!("max_values={:?}", max_values);
// println!("max_values={}", max_values);

timings.push(("max-value", start.elapsed()));

let start = Instant::now();
// Create a boolean mask where max values are greater than the confidence threshold
let xc = max_values.gt(conf_thresh);
timings.push(("conf-match", start.elapsed()));
let start = Instant::now();
let t = xc.nonzero().view(-1);
// println!("tensor device: non-zero: {:?}", t.device());
timings.push(("non-zero", start.elapsed()));
let start = Instant::now();

// println!("t={:?} t={t}", t);
let indexes = Vec::<i64>::try_from(t).expect("can't get indexes where confidence is met");
// println!("indexes={:?}", indexes);
// println!("xc={:?}", xc);

let mut bboxes: Vec<Vec<BBox>> = (0..nclasses).map(|_| vec![]).collect();

timings.push(("prolog", start.elapsed()));
timings.push(("index-to-cpu", start.elapsed()));
let start = Instant::now();

for index in indexes {
Expand Down Expand Up @@ -328,9 +339,7 @@ impl DetectionTools {
}
}
timings.push(("epilog", start.elapsed()));
// println!("timings={:?}", timings);

return result;
return (result, timings);
}

fn iou(b1: &BBox, b2: &BBox) -> f64 {
Expand All @@ -344,6 +353,7 @@ impl DetectionTools {
i_area / (b1_area + b2_area - i_area)
}
}

// global image preprocessing:
// 1) resize, keep aspect ratio
// 2) padding to square tensor
Expand Down

0 comments on commit b5628fe

Please sign in to comment.