-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add benches, refactor preprocess method
- Loading branch information
Michal Conos
committed
Sep 2, 2024
1 parent
12c68d7
commit e5a5022
Showing
4 changed files
with
57 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
use criterion::{black_box, criterion_group, criterion_main, Criterion}; | ||
use yolo_v8::utils::preprocess; | ||
|
||
fn bench_tensor_preprocess(c: &mut Criterion) { | ||
let image = tch::vision::image::load("images/bus.jpg").expect("can't load image"); | ||
c.bench_function("bench_tensor_preprocess", |b| { | ||
b.iter(|| preprocess(black_box(&image), black_box(640))) | ||
}); | ||
} | ||
|
||
criterion_group!(benches, bench_tensor_preprocess); | ||
criterion_main!(benches); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,112 +1,59 @@ | ||
use tch::Tensor; | ||
|
||
pub fn preprocess_torch(path: &str, square_size: i32) -> Tensor { | ||
let image = tch::vision::image::load(path).expect("can't load image"); | ||
pub fn preprocess(image: &Tensor, square_size: i64) -> Tensor { | ||
let (_, height, width) = image.size3().unwrap(); | ||
let (uw, uh) = square(square_size, width as i32, height as i32); | ||
let scaled_image = | ||
tch::vision::image::resize(&image, uw as i64, uh as i64).expect("can't resize image"); | ||
let scaled_image = Vec::<u8>::try_from(scaled_image.reshape([-1])).expect("vec"); | ||
let mut gray: Vec<u8> = vec![114; (square_size * square_size * 3) as usize]; | ||
let (uw, uh) = square64(square_size, width, height); | ||
let scaled_image = tch::vision::image::resize(&image, uw, uh).expect("can't resize image"); | ||
|
||
let gray: Vec<u8> = vec![114; (square_size * square_size * 3) as usize]; | ||
let bg = Tensor::from_slice(&gray).reshape([3, square_size, square_size]); | ||
let dh = (square_size - uh) / 2; | ||
let dw = (square_size - uw) / 2; | ||
let mut src_y = 0; | ||
if uw > uh { | ||
for y in dh..dh + uh { | ||
let line = get_hline(&scaled_image, (uw as usize, uh as usize), src_y); | ||
// println!("line={:?}", line); | ||
put_hline( | ||
&mut gray, | ||
(square_size as usize, square_size as usize), | ||
0, | ||
y as usize, | ||
line, | ||
); | ||
src_y += 1; | ||
} | ||
} | ||
if uh > uw { | ||
for y in 0..square_size { | ||
let line = get_hline(&scaled_image, (uw as usize, uh as usize), src_y); | ||
// println!("line={:?}", line); | ||
put_hline( | ||
&mut gray, | ||
(square_size as usize, square_size as usize), | ||
dw as usize, | ||
y as usize, | ||
line, | ||
); | ||
src_y += 1; | ||
} | ||
} | ||
|
||
let border = Tensor::from_slice(&gray).reshape([3, square_size as i64, square_size as i64]); | ||
tch::vision::image::save(&border, "border.jpg").expect("can't save image"); | ||
border | ||
} | ||
|
||
fn put_hline( | ||
v: &mut Vec<u8>, | ||
(w, h): (usize, usize), | ||
x_off: usize, | ||
y: usize, | ||
(r, g, b): (Vec<u8>, Vec<u8>, Vec<u8>), | ||
) { | ||
let r_off = 0; | ||
let g_off = w * h; | ||
let b_off = 2 * w * h; | ||
let mut s_off = y * w; | ||
for i in 0..r.len() { | ||
// println!("getline: y={y}, i={i}, s_off={s_off} b_off={b_off} idx={idx}"); | ||
v[r_off + x_off + s_off] = r[i]; | ||
v[g_off + x_off + s_off] = g[i]; | ||
v[b_off + x_off + s_off] = b[i]; | ||
s_off += 1; | ||
} | ||
bg.narrow(2, dw, uw).narrow(1, dh, uh).copy_(&scaled_image); | ||
bg | ||
} | ||
|
||
fn get_hline(v: &Vec<u8>, (w, h): (usize, usize), y: usize) -> (Vec<u8>, Vec<u8>, Vec<u8>) { | ||
let r_off = 0; | ||
let g_off = w * h; | ||
let b_off = 2 * w * h; | ||
let mut s_off = y * w; | ||
let mut r = vec![0; w]; | ||
let mut g = vec![0; w]; | ||
let mut b = vec![0; w]; | ||
for i in 0..w { | ||
// println!("getline: y={y}, i={i}, s_off={s_off} b_off={b_off} idx={idx}"); | ||
r[i] = v[r_off + s_off]; | ||
g[i] = v[g_off + s_off]; | ||
b[i] = v[b_off + s_off]; | ||
s_off += 1; | ||
} | ||
(r, g, b) | ||
} | ||
|
||
fn square(size: i32, w: i32, h: i32) -> (i32, i32) { | ||
fn square64(size: i64, w: i64, h: i64) -> (i64, i64) { | ||
let aspect = w as f32 / h as f32; | ||
if w > h { | ||
let tw = size; | ||
let th = (tw as f32 / aspect) as i32; | ||
let th = (tw as f32 / aspect) as i64; | ||
(tw, th) | ||
} else { | ||
let th = size; | ||
let tw = (size as f32 * aspect) as i32; | ||
let tw = (size as f32 * aspect) as i64; | ||
(tw, th) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
use tch::Tensor; | ||
use super::preprocess; | ||
|
||
#[test] | ||
fn matmul() { | ||
let a = Tensor::from_slice(&[1, 1]).reshape([1, 2]); | ||
let b = Tensor::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]).reshape([2, 4]); | ||
println!("a={}", a); | ||
println!("b={}", b); | ||
let c = a.matmul(&b); | ||
println!("c={}", c); | ||
fn tensor_padding() { | ||
let image = tch::vision::image::load("images/bus.jpg").expect("can't load image"); | ||
let t = preprocess(&image, 640); | ||
println!("t={:?}", t); | ||
let r = t.size3(); | ||
assert!(r.is_ok()); | ||
let (ch, h, w) = r.unwrap(); | ||
assert_eq!(3, ch); | ||
assert_eq!(640, h); | ||
assert_eq!(640, w); | ||
tch::vision::image::save(&t, "bus_padded.jpg").expect("can't save image"); | ||
|
||
let image = tch::vision::image::load("images/katri.jpg").expect("can't load image"); | ||
|
||
let t = preprocess(&image, 640); | ||
println!("t={:?}", t); | ||
let r = t.size3(); | ||
assert!(r.is_ok()); | ||
let (ch, h, w) = r.unwrap(); | ||
assert_eq!(3, ch); | ||
assert_eq!(640, h); | ||
assert_eq!(640, w); | ||
tch::vision::image::save(&t, "katri_padded.jpg").expect("can't save image"); | ||
} | ||
} |