From 12c68d72a230667f1b613658f24dbeaef2d356fb Mon Sep 17 00:00:00 2001 From: Michal Conos Date: Sun, 1 Sep 2024 23:11:25 +0200 Subject: [PATCH] WIP: segmentation masks, lots of refactoring needed --- examples/predict/main.rs | 31 +++++-- src/lib.rs | 196 ++++++++++++++++++++++++++++++++++----- src/utils.rs | 15 +++ 3 files changed, 211 insertions(+), 31 deletions(-) diff --git a/examples/predict/main.rs b/examples/predict/main.rs index c033a2f..a9fe9ac 100644 --- a/examples/predict/main.rs +++ b/examples/predict/main.rs @@ -1,4 +1,4 @@ -use tch::TchError; +use tch::{TchError, Tensor}; use yolo_v8::{Image, YoloV8Classifier, YoloV8ObjectDetection, YoloV8Segmentation}; fn object_detection(path: &str) { @@ -30,19 +30,36 @@ fn image_classification(path: &str) { println!("classes={:?}", classes); } -fn image_segmentation() { - let image = Image::new("images/test.jpg", YoloV8Segmentation::input_dimension()); +fn image_segmentation(path: &str) { + let image = Image::new(path, YoloV8Segmentation::input_dimension()); // Load exported torchscript for object detection let yolo = YoloV8Segmentation::new(); - let classes = yolo.predict(&image); + let segmentation = yolo.predict(&image, 0.25, 0.7); + println!("segmentation={:?}", segmentation); + let mut mask_no = 0; + for seg in segmentation { + let mask = seg.mask.reshape([-1]); + let name = seg.segbox.name; + let mut rgb = Vec::new(); + let mut vec = Vec::::try_from(&mask).unwrap(); + rgb.append(&mut vec.clone()); + rgb.append(&mut vec.clone()); + rgb.append(&mut vec); + let im = Tensor::from_slice(&rgb) + .reshape([3, 160, 160]) + .g_mul_scalar(255.); + let imgname = format!("mask-{name}-{mask_no}.jpg"); + tch::vision::image::save(&im, imgname).expect("can't save image"); + mask_no += 1; + } } // YOLOv8n for object detection in image fn main() -> Result<(), TchError> { - object_detection("images/katri.jpg"); - // image_classification("images/katri.jpg"); - // image_segmentation(); + // object_detection("images/bus.jpg"); + // image_classification("images/bus.jpg"); + image_segmentation("images/test.jpg"); Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index 116bd57..96f9249 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,6 +22,24 @@ pub struct BBox { pub name: &'static str, } +#[derive(Debug, Clone, Copy)] +pub struct SegBBox { + pub xmin: f32, + pub ymin: f32, + pub xmax: f32, + pub ymax: f32, + pub conf: f32, + pub cls: usize, + pub cls_weight: [f32; 32], + pub name: &'static str, +} + +#[derive(Debug)] +pub struct SegmentationResult { + pub segbox: SegBBox, + pub mask: Tensor, +} + #[derive(Debug)] pub struct ClassConfidence { pub name: &'static str, @@ -44,7 +62,7 @@ pub struct YoloV8Classifier { impl YoloV8Classifier { pub fn new() -> Self { Self { - yolo: YOLOv8::new("models/yolov8x-cls.torchscript").expect("can't load model"), + yolo: YOLOv8::new("models/yolov8n-cls.torchscript").expect("can't load model"), } } @@ -104,17 +122,6 @@ impl YoloV8ObjectDetection { self.non_max_suppression(image, &pred.get(0), conf_thresh, iou_thresh) } - fn iou(&self, b1: &BBox, b2: &BBox) -> f64 { - let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.); - let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.); - let i_xmin = b1.xmin.max(b2.xmin); - let i_xmax = b1.xmax.min(b2.xmax); - let i_ymin = b1.ymin.max(b2.ymin); - let i_ymax = b1.ymax.min(b2.ymax); - let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.); - i_area / (b1_area + b2_area - i_area) - } - fn non_max_suppression( &self, image: &Image, @@ -187,7 +194,7 @@ impl YoloV8ObjectDetection { for index in 0..bboxes_for_class.len() { let mut drop = false; for prev_index in 0..current_index { - let iou = self.iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]); + let iou = YOLOv8::iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]); if iou > iou_thresh { drop = true; @@ -225,8 +232,14 @@ impl YoloV8Segmentation { } } - pub fn predict(&self, image: &Image) { + pub fn predict( + &self, + image: &Image, + conf_threshold: f32, + iou_threshold: f32, + ) -> Vec { let img = &image.scaled_image; + let mut result = Vec::new(); // println!("img={:?}", img); @@ -239,21 +252,133 @@ impl YoloV8Segmentation { let t = tch::IValue::Tensor(img); let pred = self.yolo.model.forward_is(&[t]).unwrap(); println!("pred={:?}", pred); - + // https://github.com/ultralytics/ultralytics/issues/2953 if let IValue::Tuple(iv) = pred { + let mut segboxes = Vec::new(); + if let IValue::Tensor(bboxes) = &iv[0] { + let t = bboxes.get(0); + println!("bboxes={:?}", t); + segboxes = self.non_max_suppression(image, &t, conf_threshold, iou_threshold); + println!("r={:?}", segboxes); + } + if let IValue::Tensor(seg) = &iv[1] { - let t = seg.get(0); - println!("seg={:?}", t); - let (nclass, w, h) = t.size3().unwrap(); - for i in 0..nclass { - let img = t.get(i); - let mut vec: Vec = vec![0.0; (img.size()[0] * img.size()[1]) as usize]; - let l = vec.len(); - img.copy_data(&mut vec, l); - println!("i={i}, v={:?}", vec); + for segbox in segboxes { + let weights = Tensor::from_slice(&segbox.cls_weight).reshape([1, 32]); + println!("weights={:?}", weights); + + let t = seg.get(0).reshape([32, 160 * 160]); + println!("seg={:?}", t); + let mask = weights.matmul(&t).reshape([1, 160, 160]).gt_(0.0); + println!("r={}", mask); + result.push(SegmentationResult { segbox, mask }); + } + } + } + result + } + + fn non_max_suppression( + &self, + image: &Image, + prediction: &tch::Tensor, + conf_thresh: f32, + iou_thresh: f32, + ) -> Vec { + let prediction = prediction.transpose(1, 0); + let (anchors, classes_no) = prediction.size2().unwrap(); + + let nclasses = (classes_no - 4) as usize; + println!("classes_no={classes_no}, anchors={anchors}"); + + let mut bboxes: Vec> = (0..nclasses).map(|_| vec![]).collect(); + + for index in 0..anchors { + let pred = Vec::::try_from(prediction.get(index)).expect("wrong type of tensor"); + + // println!("index={index}, pred={}", pred.len()); + + //FIXME + let weights: [f32; 32] = pred[84..116].try_into().expect("cccc"); + + for i in 4..84 as usize { + let confidence = pred[i]; + if confidence > conf_thresh { + let class_index = i - 4; + // println!( + // "confidence={confidence}, class_index={class_index} class_name={}", + // CLASSES[class_index] + // ); + + let (_, orig_h, orig_w) = image.image.size3().unwrap(); + let (_, sh, sw) = image.scaled_image.size3().unwrap(); + let cx = sw as f32 / 2.0; + let cy = sh as f32 / 2.0; + let mut dx = pred[0] - cx; + let mut dy = pred[1] - cy; + let mut w = pred[2]; + let mut h = pred[3]; + + let aspect = orig_w as f32 / orig_h as f32; + + if orig_w > orig_h { + dy *= aspect; + h *= aspect; + } else { + dx /= aspect; + w /= aspect; + } + + let x = cx + dx; + let y = cy + dy; + + let bbox = SegBBox { + xmin: x - w / 2., + ymin: y - h / 2., + xmax: x + w / 2., + ymax: y + h / 2., + conf: confidence, + cls: class_index, + name: DETECT_CLASSES[class_index], + cls_weight: weights, + }; + bboxes[class_index].push(bbox) } } } + + for bboxes_for_class in bboxes.iter_mut() { + bboxes_for_class.sort_by(|b1, b2| b2.conf.partial_cmp(&b1.conf).unwrap()); + + let mut current_index = 0; + for index in 0..bboxes_for_class.len() { + let mut drop = false; + for prev_index in 0..current_index { + let iou = + YOLOv8::iou_seg(&bboxes_for_class[prev_index], &bboxes_for_class[index]); + + if iou > iou_thresh { + drop = true; + break; + } + } + if !drop { + bboxes_for_class.swap(current_index, index); + current_index += 1; + } + } + bboxes_for_class.truncate(current_index); + } + + let mut result = vec![]; + + for bboxes_for_class in bboxes.iter() { + for bbox in bboxes_for_class.iter() { + result.push(*bbox); + } + } + + return result; } pub fn input_dimension() -> (i64, i64) { @@ -290,6 +415,29 @@ impl YOLOv8 { pred } + + fn iou(b1: &BBox, b2: &BBox) -> f64 { + let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.); + let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.); + let i_xmin = b1.xmin.max(b2.xmin); + let i_xmax = b1.xmax.min(b2.xmax); + let i_ymin = b1.ymin.max(b2.ymin); + let i_ymax = b1.ymax.min(b2.ymax); + let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.); + i_area / (b1_area + b2_area - i_area) + } + + //FIXME !!! + fn iou_seg(b1: &SegBBox, b2: &SegBBox) -> f32 { + let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.); + let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.); + let i_xmin = b1.xmin.max(b2.xmin); + let i_xmax = b1.xmax.min(b2.xmax); + let i_ymin = b1.ymin.max(b2.ymin); + let i_ymax = b1.ymax.min(b2.ymax); + let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.); + i_area / (b1_area + b2_area - i_area) + } } pub struct Image { diff --git a/src/utils.rs b/src/utils.rs index aa39794..cef698d 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -95,3 +95,18 @@ fn square(size: i32, w: i32, h: i32) -> (i32, i32) { (tw, th) } } + +#[cfg(test)] +mod test { + use tch::Tensor; + + #[test] + fn matmul() { + let a = Tensor::from_slice(&[1, 1]).reshape([1, 2]); + let b = Tensor::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]).reshape([2, 4]); + println!("a={}", a); + println!("b={}", b); + let c = a.matmul(&b); + println!("c={}", c); + } +}