Skip to content

Commit

Permalink
add model specification in public API
Browse files Browse the repository at this point in the history
  • Loading branch information
Michal Conos committed Sep 12, 2024
1 parent ff78cd3 commit 0000121
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 5 deletions.
46 changes: 42 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,18 @@ pub mod utils;

use image::{Image, ImageCHW};
use tch::{IValue, Tensor};
use utils::{DetectionTools, SegmentationTools};
use utils::{get_model, DetectionTools, SegmentationTools};

pub(crate) mod classes;

pub enum YOLOModel {
Nano,
Small,
Medium,
Large,
Extra,
}

#[derive(Debug, Clone, Copy)]
pub struct BBox {
pub xmin: f64,
Expand Down Expand Up @@ -104,9 +112,17 @@ pub struct YoloV8Classifier {
}

impl YoloV8Classifier {
pub fn with_model(model_type: YOLOModel) -> Self {
Self {
yolo: YOLOv8::new(&get_model(model_type, utils::YOLOSpec::Classification))
.expect("can't load model"),
}
}

pub fn new() -> Self {
Self {
yolo: YOLOv8::new("models/yolov8n-cls.torchscript").expect("can't load model"),
yolo: YOLOv8::new(&get_model(YOLOModel::Nano, utils::YOLOSpec::Classification))
.expect("can't load model"),
}
}

Expand Down Expand Up @@ -149,9 +165,20 @@ pub struct YoloV8ObjectDetection {
}

impl YoloV8ObjectDetection {
pub fn with_model(model_type: YOLOModel) -> Self {
Self {
yolo: YOLOv8::new(&get_model(model_type, utils::YOLOSpec::ObjectDetection))
.expect("can't load model"),
}
}

pub fn new() -> Self {
Self {
yolo: YOLOv8::new("models/yolov8n.torchscript").expect("can't load model"),
yolo: YOLOv8::new(&get_model(
YOLOModel::Nano,
utils::YOLOSpec::ObjectDetection,
))
.expect("can't load model"),
}
}

Expand Down Expand Up @@ -182,9 +209,20 @@ pub struct YoloV8Segmentation {
}

impl YoloV8Segmentation {
pub fn with_model(model_type: YOLOModel) -> Self {
Self {
yolo: YOLOv8::new(&utils::get_model(model_type, utils::YOLOSpec::Segmentation))
.expect("can't load model"),
}
}

pub fn new() -> Self {
Self {
yolo: YOLOv8::new("models/yolov8n-seg.torchscript").expect("can't load model"),
yolo: YOLOv8::new(&utils::get_model(
YOLOModel::Nano,
utils::YOLOSpec::Segmentation,
))
.expect("can't load model"),
}
}

Expand Down
144 changes: 143 additions & 1 deletion src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,30 @@
use tch::{IValue, Tensor};

use crate::{image::ImageCHW, BBox, SegBBox, SegmentationResult};
use crate::{image::ImageCHW, BBox, SegBBox, SegmentationResult, YOLOModel};

pub(crate) enum YOLOSpec {
Classification,
ObjectDetection,
Segmentation,
}

pub(crate) fn get_model(model_type: YOLOModel, spec: YOLOSpec) -> String {
let specialization = match spec {
YOLOSpec::Classification => "-cls",
YOLOSpec::ObjectDetection => "",
YOLOSpec::Segmentation => "-seg",
};

let model = match model_type {
YOLOModel::Nano => "n",
YOLOModel::Small => "s",
YOLOModel::Medium => "m",
YOLOModel::Large => "l",
YOLOModel::Extra => "x",
};

format!("models/yolov8{model}{specialization}.torchscript")
}

pub struct SegmentationTools {}
pub struct DetectionTools {}
Expand Down Expand Up @@ -318,6 +342,8 @@ fn square64(size: i64, w: i64, h: i64) -> (i64, i64) {

#[cfg(test)]
mod test {
use crate::utils::get_model;

use super::preprocess;

#[test]
Expand Down Expand Up @@ -345,4 +371,120 @@ mod test {
assert_eq!(640, w);
tch::vision::image::save(&t, "katri_padded.jpg").expect("can't save image");
}

#[test]
fn get_model_test() {
assert_eq!(
"models/yolov8n-cls.torchscript".to_owned(),
get_model(
crate::YOLOModel::Nano,
crate::utils::YOLOSpec::Classification
)
);

assert_eq!(
"models/yolov8s-cls.torchscript".to_owned(),
get_model(
crate::YOLOModel::Small,
crate::utils::YOLOSpec::Classification
)
);

assert_eq!(
"models/yolov8m-cls.torchscript".to_owned(),
get_model(
crate::YOLOModel::Medium,
crate::utils::YOLOSpec::Classification
)
);

assert_eq!(
"models/yolov8l-cls.torchscript".to_owned(),
get_model(
crate::YOLOModel::Large,
crate::utils::YOLOSpec::Classification
)
);

assert_eq!(
"models/yolov8x-cls.torchscript".to_owned(),
get_model(
crate::YOLOModel::Extra,
crate::utils::YOLOSpec::Classification
)
);

assert_eq!(
"models/yolov8n-seg.torchscript".to_owned(),
get_model(crate::YOLOModel::Nano, crate::utils::YOLOSpec::Segmentation)
);

assert_eq!(
"models/yolov8s-seg.torchscript".to_owned(),
get_model(
crate::YOLOModel::Small,
crate::utils::YOLOSpec::Segmentation
)
);

assert_eq!(
"models/yolov8m-seg.torchscript".to_owned(),
get_model(
crate::YOLOModel::Medium,
crate::utils::YOLOSpec::Segmentation
)
);

assert_eq!(
"models/yolov8l-seg.torchscript".to_owned(),
get_model(
crate::YOLOModel::Large,
crate::utils::YOLOSpec::Segmentation
)
);

assert_eq!(
"models/yolov8x-seg.torchscript".to_owned(),
get_model(
crate::YOLOModel::Extra,
crate::utils::YOLOSpec::Segmentation
)
);

assert_eq!(
"models/yolov8n.torchscript".to_owned(),
get_model(
crate::YOLOModel::Nano,
crate::utils::YOLOSpec::ObjectDetection
)
);
assert_eq!(
"models/yolov8s.torchscript".to_owned(),
get_model(
crate::YOLOModel::Small,
crate::utils::YOLOSpec::ObjectDetection
)
);
assert_eq!(
"models/yolov8m.torchscript".to_owned(),
get_model(
crate::YOLOModel::Medium,
crate::utils::YOLOSpec::ObjectDetection
)
);
assert_eq!(
"models/yolov8l.torchscript".to_owned(),
get_model(
crate::YOLOModel::Large,
crate::utils::YOLOSpec::ObjectDetection
)
);
assert_eq!(
"models/yolov8x.torchscript".to_owned(),
get_model(
crate::YOLOModel::Extra,
crate::utils::YOLOSpec::ObjectDetection
)
);
}
}

0 comments on commit 0000121

Please sign in to comment.