From 04eb7c203c67021b1d3efba66a4078862bf75b1b Mon Sep 17 00:00:00 2001 From: Christian M Date: Tue, 26 Sep 2023 18:07:31 +0200 Subject: [PATCH] :sparkles::bento: adds efficientnet pretrained model --- README.md | 6 ++++++ assets/Birds-Classifier-EfficientNetB2.onnx | 1 + scripts/README.md | 6 ++++++ scripts/download_efficientnet.py | 3 +++ scripts/train.py | 2 -- src/efficientnet.rs | 6 ++++-- 6 files changed, 20 insertions(+), 4 deletions(-) create mode 120000 assets/Birds-Classifier-EfficientNetB2.onnx create mode 100644 scripts/download_efficientnet.py diff --git a/README.md b/README.md index 4a2567a..d01ab5a 100644 --- a/README.md +++ b/README.md @@ -99,3 +99,9 @@ Find an example in [example.ipynb](example.ipynb) pip install jupyter jupyter notebook example.ipynb ``` + +### Thanks + +Thanks to [gpiosenka](https://www.kaggle.com/gpiosenka/100-bird-species) for the dataset. + +Thanks to [dennisjooo](https://huggingface.co/dennisjooo/Birds-Classifier-EfficientNetB2) for the efficientnet model. diff --git a/assets/Birds-Classifier-EfficientNetB2.onnx b/assets/Birds-Classifier-EfficientNetB2.onnx new file mode 120000 index 0000000..06baab8 --- /dev/null +++ b/assets/Birds-Classifier-EfficientNetB2.onnx @@ -0,0 +1 @@ +../../../../.cache/huggingface/hub/models--dennisjooo--Birds-Classifier-EfficientNetB2/blobs/179d9b6d4229355b0b3d9d1a0cb144971d2dfdc70cf7c68ed38ee73a17991e65 \ No newline at end of file diff --git a/scripts/README.md b/scripts/README.md index 55b9920..9297cda 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -36,3 +36,9 @@ python export_onnx_model.py cp birds_mobilenetv2.onnx ../assets/ cp birds_labels.txt ../assets/ ``` + +## Download EfficientNetB2 + +```sh +python download_efficientnet.py +``` diff --git a/scripts/download_efficientnet.py b/scripts/download_efficientnet.py new file mode 100644 index 0000000..dd9a120 --- /dev/null +++ b/scripts/download_efficientnet.py @@ -0,0 +1,3 @@ +# Importing the libraries needed +from huggingface_hub import hf_hub_download +hf_hub_download(repo_id="dennisjooo/Birds-Classifier-EfficientNetB2", filename="Birds-Classifier-EfficientNetB2.onnx", local_dir="../assets/") \ No newline at end of file diff --git a/scripts/train.py b/scripts/train.py index 03b34fa..1f4e535 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -100,9 +100,7 @@ def get_birds_mobilenet(): checkpoint_path = "./checkpoints/birds_mobilenet/" -# Check if checkpoint directory exists and contains saved_model.pb if os.path.exists(checkpoint_path) and os.path.isfile(os.path.join(checkpoint_path, 'saved_model.pb')): - # If SavedModel exists, load the entire model model = tf.keras.models.load_model(checkpoint_path) print(f"Loaded model from {checkpoint_path}") else: diff --git a/src/efficientnet.rs b/src/efficientnet.rs index 4fe56a4..767b8c7 100644 --- a/src/efficientnet.rs +++ b/src/efficientnet.rs @@ -8,7 +8,7 @@ const SIZE: usize = 260; impl Birds { pub fn model() -> Result> { - let data = include_bytes!("../assets/birds_efficientnetb2.onnx"); + let data = include_bytes!("../assets/Birds-Classifier-EfficientNetB2.onnx"); let mut cursor = Cursor::new(data); let model = tract_onnx::onnx() .model_for_read(&mut cursor)? @@ -43,7 +43,9 @@ impl Birds { ); let tensor: Tensor = tract_ndarray::Array4::from_shape_fn((1, 3, SIZE, SIZE), |(_, c, y, x)| { - (resized[(x as _, y as _)][c] as f32 / 255.0) + let mean = [0.485, 0.456, 0.406][c]; + let std = [0.47853944, 0.4732864, 0.47434163][c]; + (resized[(x as _, y as _)][c] as f32 / 255.0 - mean) / std }) .into();