diff --git a/README.md b/README.md index 1febb70..2f7e42d 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ Configuration variables: - **save_file_folder**: (Optional) The folder to save processed images to. Note that folder path should be added to [whitelist_external_dirs](https://www.home-assistant.io/docs/configuration/basic/) - **save_timestamped_file**: (Optional, default `False`, requires `save_file_folder` to be configured) Save the processed image with the time of detection in the filename. - **source**: Must be a camera. -- **target**: The target object class, default `person`. +- **target**: The target object class, default `person`. Can also be a list of targets. - **confidence**: (Optional) The confidence (in %) above which detected targets are counted in the sensor state. Default value: 80 - **name**: (Optional) A custom name for the the entity. diff --git a/custom_components/deepstack_object/image_processing.py b/custom_components/deepstack_object/image_processing.py index bd11be2..40ed688 100644 --- a/custom_components/deepstack_object/image_processing.py +++ b/custom_components/deepstack_object/image_processing.py @@ -50,7 +50,7 @@ CONF_SAVE_TIMESTAMPTED_FILE = "save_timestamped_file" DATETIME_FORMAT = "%Y-%m-%d_%H:%M:%S" DEFAULT_API_KEY = "" -DEFAULT_TARGET = "person" +DEFAULT_TARGET = ["person"] DEFAULT_TIMEOUT = 10 EVENT_OBJECT_DETECTED = "image_processing.object_detected" EVENT_FILE_SAVED = "image_processing.file_saved" @@ -67,7 +67,9 @@ vol.Required(CONF_PORT): cv.port, vol.Optional(CONF_API_KEY, default=DEFAULT_API_KEY): cv.string, vol.Optional(CONF_TIMEOUT, default=DEFAULT_TIMEOUT): cv.positive_int, - vol.Optional(CONF_TARGET, default=DEFAULT_TARGET): cv.string, + vol.Optional(CONF_TARGET, default=DEFAULT_TARGET): vol.All( + cv.ensure_list, [cv.string] + ), vol.Optional(CONF_SAVE_FILE_FOLDER): cv.isdir, vol.Optional(CONF_SAVE_TIMESTAMPTED_FILE, default=False): cv.boolean, } @@ -76,8 +78,9 @@ def get_box(prediction: dict, img_width: 
int, img_height: int): """ - Return the relative bounxing box coordinates - defined by the tuple (y_min, x_min, y_max, x_max) + Return the relative bounding box coordinates. + + Defined by the tuple (y_min, x_min, y_max, x_max) where the coordinates are floats in the range [0.0, 1.0] and relative to the width and height of the image. """ @@ -145,7 +148,8 @@ def __init__( camera_name = split_entity_id(camera_entity)[1] self._name = "deepstack_object_{}".format(camera_name) self._state = None - self._targets_confidences = [] + self._targets_confidences = [None] * len(self._target) + self._targets_found = [0] * len(self._target) self._predictions = {} self._summary = {} self._last_detection = None @@ -161,27 +165,30 @@ def process_image(self, image): io.BytesIO(bytearray(image)) ).size self._state = None - self._targets_confidences = [] + self._targets_confidences = [None] * len(self._target) + self._targets_found = [0] * len(self._target) self._predictions = {} self._summary = {} try: self._dsobject.detect(image) except ds.DeepstackException as exc: - _LOGGER.error("Depstack error : %s", exc) + _LOGGER.error("Deepstack error : %s", exc) return self._predictions = self._dsobject.predictions.copy() - if len(self._predictions) > 0: - raw_confidences = ds.get_object_confidences(self._predictions, self._target) - self._targets_confidences = [ - ds.format_confidence(confidence) for confidence in raw_confidences - ] - self._state = len( - ds.get_confidences_above_threshold( - self._targets_confidences, self._confidence + if self._predictions: + for i, target in enumerate(self._target): + raw_confidences = ds.get_object_confidences(self._predictions, target) + self._targets_confidences[i] = [ + ds.format_confidence(confidence) for confidence in raw_confidences + ] + self._targets_found[i] = len( + ds.get_confidences_above_threshold( + self._targets_confidences[i], self._confidence + ) ) - ) + self._state = sum(self._targets_found) if self._state > 0: self._last_detection = 
dt_util.now().strftime(DATETIME_FORMAT) self._summary = ds.get_objects_summary(self._predictions) @@ -193,14 +200,13 @@ def process_image(self, image): def save_image(self, image, predictions, target, directory): """Save a timestamped image with bounding boxes around targets.""" - img = Image.open(io.BytesIO(bytearray(image))).convert("RGB") draw = ImageDraw.Draw(img) for prediction in predictions: prediction_confidence = ds.format_confidence(prediction["confidence"]) if ( - prediction["label"] == target + prediction["label"] in target and prediction_confidence >= self._confidence ): box = get_box(prediction, self._image_width, self._image_height) @@ -213,12 +219,12 @@ def save_image(self, image, predictions, target, directory): color=RED, ) - latest_save_path = directory + "{}_latest_{}.jpg".format(self._name, target) + latest_save_path = directory + "{}_latest_{}.jpg".format(self._name, target[0]) img.save(latest_save_path) if self._save_timestamped_file: timestamp_save_path = directory + "{}_{}_{}.jpg".format( - self._name, target, self._last_detection + self._name, target[0], self._last_detection ) out_file = open(timestamp_save_path, "wb") @@ -231,7 +237,6 @@ def save_image(self, image, predictions, target, directory): def fire_prediction_events(self, predictions, confidence): """Fire events based on predictions if above confidence threshold.""" - for prediction in predictions: if ds.format_confidence(prediction["confidence"]) > confidence: box = get_box(prediction, self._image_width, self._image_height) @@ -246,7 +251,7 @@ def fire_prediction_events(self, predictions, confidence): ) def fire_saved_file_event(self, save_path): - """Fire event when saving a file""" + """Fire event when saving a file.""" self.hass.bus.fire( EVENT_FILE_SAVED, {ATTR_ENTITY_ID: self.entity_id, FILE: save_path} ) @@ -269,9 +274,9 @@ def name(self): @property def unit_of_measurement(self): """Return the unit of measurement.""" - target = self._target - if self._state != None and 
self._state > 1: - target += "s" + target = self._target if len(self._target) == 1 else "target" + if self._state is not None and self._state > 1: + return target + "s" return target @property @@ -279,6 +284,6 @@ def device_state_attributes(self): """Return device specific state attributes.""" attr = {} if self._last_detection: - attr["last_{}_detection".format(self._target)] = self._last_detection + attr["last_detection"] = self._last_detection attr["summary"] = self._summary return attr