Skip to content

Commit

Permalink
Merge pull request #36 from surfriderfoundationeurope/feature/update_…
Browse files Browse the repository at this point in the history
…yolo

Feature/update yolo
  • Loading branch information
charlesollion authored Jan 26, 2023
2 parents ddb35b1 + 9a4104c commit 8bfdb36
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 20 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "plastic-origins"
version = "2.2.2a0"
version = "2.2.3"

description = "A package containing methods commonly used to make inferences"
repository = "https://github.com/surfriderfoundationeurope/surfnet"
Expand Down
23 changes: 4 additions & 19 deletions src/plasticorigins/detection/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,6 @@
from yolov5.utils.general import non_max_suppression
from typing import Any, Tuple, Union, Dict, Optional

# This has to be kept for now as this depends on the model training
id_categories = {
1: "Insulating material",
4: "Drum",
2: "Bottle-shaped",
3: "Can-shaped",
5: "Other packaging",
6: "Tire",
7: "Fishing net / cord",
8: "Easily namable",
9: "Unclear",
0: "Sheet / tarp / plastic bag / fragment",
}

categories_id = {v: k for k, v in id_categories.items()}
get_id = lambda cat: categories_id[cat]


def load_model(
model_path: str, device: str, conf: float = 0.35, iou: float = 0.50
Expand All @@ -62,9 +45,10 @@ def load_model(
model = yolov5.load(model_path, device=device)
model.conf = conf
model.iou = iou
model.classes = None
model.multi_label = False
model.max_det = 1000
categories_id = {v: k for k, v in model.names.items()}
model.get_id = lambda cat: categories_id[cat]

return model

Expand Down Expand Up @@ -117,7 +101,8 @@ def predict_yolo(
bboxes = voc2centerdims(bboxes)
bboxes = bboxes.astype(int)
confs = preds.confidence.values
labels = np.array(list(map(get_id, preds.name.values)))
# Converts back names to ids
labels = np.array(list(map(model.get_id, preds.name.values)))
return bboxes, confs, labels

else:
Expand Down

0 comments on commit 8bfdb36

Please sign in to comment.