Skip to content

Commit

Permalink
✨🍱 adds efficientnet pretrained model
Browse files Browse the repository at this point in the history
  • Loading branch information
chriamue committed Sep 26, 2023
1 parent 9e40589 commit 04eb7c2
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 4 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,9 @@ Find an example in [example.ipynb](example.ipynb)
pip install jupyter
jupyter notebook example.ipynb
```

### Thanks

Thanks to [gpiosenka](https://www.kaggle.com/gpiosenka/100-bird-species) for the dataset.

Thanks to [dennisjooo](https://huggingface.co/dennisjooo/Birds-Classifier-EfficientNetB2) for the efficientnet model.
1 change: 1 addition & 0 deletions assets/Birds-Classifier-EfficientNetB2.onnx
6 changes: 6 additions & 0 deletions scripts/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,9 @@ python export_onnx_model.py
cp birds_mobilenetv2.onnx ../assets/
cp birds_labels.txt ../assets/
```

## Download EfficientNetB2

```sh
python download_efficientnet.py
```
3 changes: 3 additions & 0 deletions scripts/download_efficientnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Importing the libraries needed
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id="dennisjooo/Birds-Classifier-EfficientNetB2", filename="Birds-Classifier-EfficientNetB2.onnx", local_dir="../assets/")
2 changes: 0 additions & 2 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ def get_birds_mobilenet():

checkpoint_path = "./checkpoints/birds_mobilenet/"

# Check if checkpoint directory exists and contains saved_model.pb
if os.path.exists(checkpoint_path) and os.path.isfile(os.path.join(checkpoint_path, 'saved_model.pb')):
# If SavedModel exists, load the entire model
model = tf.keras.models.load_model(checkpoint_path)
print(f"Loaded model from {checkpoint_path}")
else:
Expand Down
6 changes: 4 additions & 2 deletions src/efficientnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ const SIZE: usize = 260;

impl Birds {
pub fn model() -> Result<ModelType, Box<dyn std::error::Error>> {
let data = include_bytes!("../assets/birds_efficientnetb2.onnx");
let data = include_bytes!("../assets/Birds-Classifier-EfficientNetB2.onnx");
let mut cursor = Cursor::new(data);
let model = tract_onnx::onnx()
.model_for_read(&mut cursor)?
Expand Down Expand Up @@ -43,7 +43,9 @@ impl Birds {
);
let tensor: Tensor =
tract_ndarray::Array4::from_shape_fn((1, 3, SIZE, SIZE), |(_, c, y, x)| {
(resized[(x as _, y as _)][c] as f32 / 255.0)
let mean = [0.485, 0.456, 0.406][c];
let std = [0.47853944, 0.4732864, 0.47434163][c];
(resized[(x as _, y as _)][c] as f32 / 255.0 - mean) / std
})
.into();

Expand Down

0 comments on commit 04eb7c2

Please sign in to comment.