Skip to content

Commit

Permalink
add benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
Michal Conos committed Sep 2, 2024
1 parent e5a5022 commit e5f068e
Show file tree
Hide file tree
Showing 8 changed files with 511 additions and 325 deletions.
12 changes: 12 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,15 @@ criterion = { version = "0.5.1", features = ["html_reports"] }
name = "preprocess"
harness = false

[[bench]]
name = "postprocess"
harness = false

[[bench]]
name = "prediction"
harness = false

[[bench]]
name = "e2e"
harness = false

27 changes: 27 additions & 0 deletions benches/e2e.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use yolo_v8::{image::Image, YoloV8ObjectDetection, YoloV8Segmentation};

fn bench_segmentation_e2e(c: &mut Criterion) {
c.bench_function("bench_segmentation_e2e", |b| {
b.iter(|| {
let image = Image::new(black_box("images/bus.jpg"), black_box((640, 640)));
let yolo = YoloV8Segmentation::new();
let result = yolo.predict(black_box(&image), black_box(0.25), black_box(0.7));
black_box(result.postprocess())
})
});
}

fn bench_detection_e2e(c: &mut Criterion) {
c.bench_function("bench_detection_e2e", |b| {
b.iter(|| {
let image = Image::new(black_box("images/bus.jpg"), black_box((640, 640)));
let yolo = YoloV8ObjectDetection::new();
let result = yolo.predict(black_box(&image), black_box(0.25), black_box(0.7));
black_box(result.postprocess())
})
});
}

criterion_group!(benches, bench_segmentation_e2e, bench_detection_e2e);
criterion_main!(benches);
27 changes: 27 additions & 0 deletions benches/postprocess.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use yolo_v8::{image::Image, YoloV8ObjectDetection, YoloV8Segmentation};

fn bench_segmentation_postprocess(c: &mut Criterion) {
let image = Image::new("images/bus.jpg", (640, 640));
let yolo = YoloV8Segmentation::new();
let result = yolo.predict(&image, 0.25, 0.7);
c.bench_function("bench_segmentation_postprocess", |b| {
b.iter(|| black_box(result.postprocess()))
});
}

fn bench_detection_postprocess(c: &mut Criterion) {
let image = Image::new("images/bus.jpg", (640, 640));
let yolo = YoloV8ObjectDetection::new();
let result = yolo.predict(&image, 0.25, 0.7);
c.bench_function("bench_detection_postprocess", |b| {
b.iter(|| black_box(result.postprocess()))
});
}

criterion_group!(
benches,
bench_segmentation_postprocess,
bench_detection_postprocess
);
criterion_main!(benches);
25 changes: 25 additions & 0 deletions benches/prediction.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use yolo_v8::{image::Image, YoloV8ObjectDetection, YoloV8Segmentation};

fn bench_segmentation_prediction(c: &mut Criterion) {
let image = Image::new("images/bus.jpg", (640, 640));
let yolo = YoloV8Segmentation::new();
c.bench_function("bench_segmentation_prediction", |b| {
b.iter(|| black_box(yolo.predict(black_box(&image), black_box(0.25), black_box(0.7))))
});
}

fn bench_detection_prediction(c: &mut Criterion) {
let image = Image::new("images/bus.jpg", (640, 640));
let yolo = YoloV8ObjectDetection::new();
c.bench_function("bench_detection_prediction", |b| {
b.iter(|| black_box(yolo.predict(black_box(&image), black_box(0.25), black_box(0.7))))
});
}

criterion_group!(
benches,
bench_segmentation_prediction,
bench_detection_prediction
);
criterion_main!(benches);
6 changes: 3 additions & 3 deletions examples/predict/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use tch::{TchError, Tensor};
use yolo_v8::{Image, YoloV8Classifier, YoloV8ObjectDetection, YoloV8Segmentation};
use yolo_v8::{image::Image, YoloV8Classifier, YoloV8ObjectDetection, YoloV8Segmentation};

fn object_detection(path: &str) {
// Load image to perform object detection, note that YOLOv8 resolution must match
Expand All @@ -10,7 +10,7 @@ fn object_detection(path: &str) {
let yolo = YoloV8ObjectDetection::new();

// Predict with non-max-suppression in the end
let bboxes = yolo.predict(&image, 0.25, 0.7);
let bboxes = yolo.predict(&image, 0.25, 0.7).postprocess();
println!("bboxes={:?}", bboxes);

// Draw rectangles around detected objects
Expand All @@ -36,7 +36,7 @@ fn image_segmentation(path: &str) {
// Load exported torchscript for object detection
let yolo = YoloV8Segmentation::new();

let segmentation = yolo.predict(&image, 0.25, 0.7);
let segmentation = yolo.predict(&image, 0.25, 0.7).postprocess();
println!("segmentation={:?}", segmentation);
let mut mask_no = 0;
for seg in segmentation {
Expand Down
78 changes: 78 additions & 0 deletions src/image.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use tch::Tensor;

use crate::{utils, BBox};

// Image channels, height and width
pub type ImageCHW = (i64, i64, i64);

pub struct Image {
width: i64,
height: i64,
pub(crate) image: Tensor,
pub(crate) scaled_image: Tensor,
pub(crate) image_dim: ImageCHW,
pub(crate) scaled_image_dim: ImageCHW,
}

impl Image {
fn from_tensor(image: Tensor, dimension: (i64, i64)) -> Self {
let width = dimension.0;
let height = dimension.1;

let scaled_image = utils::preprocess(&image, dimension.0);
let image_dim = image.size3().unwrap();
let scaled_image_dim = scaled_image.size3().unwrap();
Self {
width,
height,
image,
scaled_image,
image_dim,
scaled_image_dim,
}
}

pub fn from_slice(
slice: &[u8],
orig_width: i64,
orig_height: i64,
dimension: (i64, i64),
) -> Self {
let image = Tensor::from_slice(slice).view((3, orig_height, orig_width));
Self::from_tensor(image, dimension)
}

pub fn new(path: &str, dimension: (i64, i64)) -> Self {
let image = tch::vision::image::load(path).expect("can't load image");
Self::from_tensor(image, dimension)
}

fn draw_line(t: &mut tch::Tensor, x1: i64, x2: i64, y1: i64, y2: i64) {
let color = Tensor::from_slice(&[255., 255., 0.]).view([3, 1, 1]);
t.narrow(2, x1, x2 - x1)
.narrow(1, y1, y2 - y1)
.copy_(&color)
}

pub fn draw_rectangle(&mut self, bboxes: &Vec<BBox>) {
let image = &mut self.image;
let (_, initial_h, initial_w) = image.size3().expect("can't get image size");
let w_ratio = initial_w as f64 / self.width as f64;
let h_ratio = initial_h as f64 / self.height as f64;

for bbox in bboxes.iter() {
let xmin = ((bbox.xmin * w_ratio) as i64).clamp(0, initial_w - 1);
let ymin = ((bbox.ymin * h_ratio) as i64).clamp(0, initial_h - 1);
let xmax = ((bbox.xmax * w_ratio) as i64).clamp(0, initial_w - 1);
let ymax = ((bbox.ymax * h_ratio) as i64).clamp(0, initial_h - 1);
Self::draw_line(image, xmin, xmax, ymin, ymax.min(ymin + 2));
Self::draw_line(image, xmin, xmax, ymin.max(ymax - 2), ymax);
Self::draw_line(image, xmin, xmax.min(xmin + 2), ymin, ymax);
Self::draw_line(image, xmin.max(xmax - 2), xmax, ymin, ymax);
}
}

pub fn save(&self, path: &str) {
tch::vision::image::save(&self.image, path).expect("can't save image");
}
}
Loading

0 comments on commit e5f068e

Please sign in to comment.