Update object_detection.py (#81)
Here's a brief analysis of the provided code, focusing on key issues and improvements:

1. **Model Loading**:
   - The object detection model is loaded globally at import time. That is fine for a single-use script, but it forces the (potentially slow) model load whenever the module is imported, even if no detection is run. Wrapping the model load in a function keeps the module modular and defers the cost until it is needed.

2. **Image Resizing**:
   - The `detect_objects` function resizes images directly to the model’s input size without preserving the aspect ratio, which can distort objects and hurt detection accuracy. An aspect-ratio-preserving resize with padding (letterboxing) avoids the distortion; a sketch appears after the summary below.

3. **Object Detection Inference**:
   - The image is converted to a tensor and the batch dimension is added in a separate step; this can be streamlined by adding the batch dimension and specifying the dtype in a single `tf.convert_to_tensor` call.

4. **Class Mapping**:
   - The COCO class names are hard-coded, which makes the mapping tedious to update or extend. Externalizing the list (e.g., to a JSON file) improves maintainability; an example of the expected file format appears after the summary.

5. **Bounding Box Drawing and Item Detection**:
   - The `draw_boxes` function mixes image annotation (drawing boxes) with printing/logging, which weakens separation of concerns. Moving the logging into a dedicated function would improve readability; a small helper is sketched after the summary.

6. **Image Display and Environment Compatibility**:
   - The code calls `cv2.imshow()` and `cv2.waitKey(0)` to display images, which does not work in headless environments (e.g., servers, CI) or Jupyter notebooks. A conditional display check, or a save-to-disk fallback, improves compatibility; see the display sketch after the summary.

### Summary of Key Improvements
- Modularize model loading to allow reuse.
- Preserve image aspect ratio when resizing.
- Streamline tensor conversion for efficiency.
- Externalize class mapping to enhance code maintainability.
- Separate logging and drawing functions.
- Use a display condition to improve compatibility across environments.

These adjustments would make the code cleaner, more efficient, and easier to maintain, while preserving the original functionality.
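
### Illustrative Sketches

For point 2, here is a minimal letterbox-resize sketch. It is not part of the commit; the function name `letterbox_resize` and the 320-pixel target are illustrative assumptions matching the model's input size.

```python
import cv2
import numpy as np

def letterbox_resize(image, target_size=320):
    """Resize while preserving aspect ratio, then pad to a square canvas.

    Returns the padded image plus the scale and offsets needed to map
    the model's normalized boxes back onto the original image.
    """
    h, w = image.shape[:2]
    scale = target_size / max(h, w)
    new_w, new_h = int(w * scale), int(h * scale)
    resized = cv2.resize(image, (new_w, new_h))

    # Center the resized image on a zero-padded square canvas
    pad_x = (target_size - new_w) // 2
    pad_y = (target_size - new_h) // 2
    padded = np.zeros((target_size, target_size, 3), dtype=image.dtype)
    padded[pad_y:pad_y + new_h, pad_x:pad_x + new_w] = resized
    return padded, scale, pad_x, pad_y
```

Because the detector's boxes are normalized to the padded input, the padding offsets must be subtracted and the coordinates divided by `scale` before drawing on the original frame.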
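For point 4, the commit reads the mapping from `class_names.json`, which is not shown in the diff. A plausible format, assuming JSON string keys to match the `class_names.get(str(class_id), 'Unknown')` lookup, can be generated once from the old hard-coded dictionary:

```python
import json

# Assumed one-off helper: dump the former hard-coded COCO mapping to JSON.
# Keys become strings because JSON object keys are always strings.
coco_classes = {
    1: 'person', 2: 'bicycle', 3: 'car',
    # ... remaining entries from the removed class_names dictionary ...
    79: 'hair drier', 80: 'toothbrush',
}

with open("class_names.json", "w") as f:
    json.dump({str(k): v for k, v in coco_classes.items()}, f, indent=2)
```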
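For point 5, one way to keep annotation and logging apart is a helper that only reports detections. The name `log_detections` and the default tracked IDs are illustrative, mirroring the logic already in `draw_boxes`:

```python
def log_detections(class_names, class_ids, scores, threshold=0.5, tracked_ids=(64, 67)):
    """Print each confident detection and return the set of tracked item names."""
    found = set()
    for class_id, score in zip(class_ids, scores):
        if score >= threshold:
            name = class_names.get(str(class_id), "Unknown")
            print(f"Detected: {name}, Confidence: {score:.2%}")
            if class_id in tracked_ids:
                found.add(name)
    print("Found:", ", ".join(found) if found else "None")
    return found
```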
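For point 6, the committed code checks `DISPLAY`, which only exists under X11 on Linux. A slightly broader sketch (the function name and fallback path are assumptions) shows the annotated image when a GUI is likely available and otherwise writes it to disk:

```python
import os
import sys
import cv2

def show_or_save(image_bgr, window_name="Object Detection", fallback_path="detections.jpg"):
    """Show the annotated image when a GUI is available; otherwise save it."""
    # Windows and macOS usually have a GUI; on Linux rely on the X11 DISPLAY variable.
    gui_available = os.name == "nt" or sys.platform == "darwin" or "DISPLAY" in os.environ
    if gui_available:
        cv2.imshow(window_name, image_bgr)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    else:
        cv2.imwrite(fallback_path, image_bgr)
        print(f"No display detected; annotated image saved to {fallback_path}")
```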
Khushi-Dua authored Nov 4, 2024
1 parent 2de8e58 commit 3265ecf
Showing 1 changed file with 60 additions and 148 deletions.
Backend/object_detection.py

```diff
@@ -2,168 +2,80 @@
 import numpy as np
 import tensorflow as tf
 import tensorflow_hub as hub
+import json
+import os  # needed for the DISPLAY check in main()
 
-# Load the SSD MobileNet V2 model from TensorFlow Hub
-detector = hub.load("https://www.kaggle.com/models/tensorflow/ssd-mobilenet-v2/TensorFlow2/fpnlite-320x320/1")
-
-# Mapping class IDs to class names for COCO dataset
-class_names = {
-    1: 'person',
-    2: 'bicycle',
-    3: 'car',
-    4: 'motorcycle',
-    5: 'airplane',
-    6: 'bus',
-    7: 'train',
-    8: 'truck',
-    9: 'boat',
-    10: 'traffic light',
-    11: 'fire hydrant',
-    12: 'stop sign',
-    13: 'parking meter',
-    14: 'bench',
-    15: 'bird',
-    16: 'cat',
-    17: 'dog',
-    18: 'horse',
-    19: 'sheep',
-    20: 'cow',
-    21: 'elephant',
-    22: 'bear',
-    23: 'zebra',
-    24: 'giraffe',
-    25: 'backpack',
-    26: 'umbrella',
-    27: 'handbag',
-    28: 'tie',
-    29: 'suitcase',
-    30: 'frisbee',
-    31: 'skis',
-    32: 'snowboard',
-    33: 'sports ball',
-    34: 'kite',
-    35: 'baseball bat',
-    36: 'baseball glove',
-    37: 'skateboard',
-    38: 'surfboard',
-    39: 'tennis racket',
-    40: 'bottle',
-    41: 'wine glass',
-    42: 'cup',
-    43: 'fork',
-    44: 'knife',
-    45: 'spoon',
-    46: 'bowl',
-    47: 'banana',
-    48: 'apple',
-    49: 'sandwich',
-    50: 'orange',
-    51: 'broccoli',
-    52: 'carrot',
-    53: 'hot dog',
-    54: 'pizza',
-    55: 'donut',
-    56: 'cake',
-    57: 'chair',
-    58: 'couch',
-    59: 'potted plant',
-    60: 'bed',
-    61: 'dining table',
-    62: 'toilet',
-    63: 'TV',
-    64: 'laptop',
-    65: 'mouse',
-    66: 'remote',
-    67: 'keyboard',
-    68: 'cell phone',
-    69: 'microwave',
-    70: 'oven',
-    71: 'toaster',
-    72: 'sink',
-    73: 'refrigerator',
-    74: 'book',
-    75: 'clock',
-    76: 'vase',
-    77: 'scissors',
-    78: 'teddy bear',
-    79: 'hair drier',
-    80: 'toothbrush'
-}
+# Function to load the SSD MobileNet V2 model
+def load_detector_model():
+    return hub.load("https://www.kaggle.com/models/tensorflow/ssd-mobilenet-v2/TensorFlow2/fpnlite-320x320/1")
+
+# Load class names from an external file
+with open("class_names.json") as f:
+    class_names = json.load(f)
 
 # Function to perform object detection
-def detect_objects(image):
-    # Convert the image to a tensor and resize it
-    image_resized = cv2.resize(image, (320, 320))  # Resize to model input size
-    input_tensor = tf.convert_to_tensor(image_resized)
-    input_tensor = input_tensor[tf.newaxis, ...]  # Add batch dimension
-
-    # Perform inference
-    detector_output = detector(input_tensor)
-
-    return detector_output
+def detect_objects(model, image):
+    # Resize image and maintain aspect ratio
+    h, w = image.shape[:2]
+    scale_factor = 320 / max(h, w)
+    image_resized = cv2.resize(image, (int(w * scale_factor), int(h * scale_factor)))
+    input_tensor = tf.convert_to_tensor(image_resized[tf.newaxis, ...], dtype=tf.uint8)
+
+    return model(input_tensor)
 
 # Load and preprocess the image
 def load_image(image_path):
     # Load an image from a file path
     image = cv2.imread(image_path)
     return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
 # Draw boxes with information overlay
 def draw_boxes(image, boxes, class_ids, scores, threshold=0.5):
-    height, width, _ = image.shape
-    found_items = []  # List to store found items
-
-    for i in range(len(scores)):
-        if scores[i] > threshold:
-            box = boxes[i]
-            (ymin, xmin, ymax, xmax) = box
-            # Scale bounding box to original image dimensions
-            xmin = int(xmin * width)
-            xmax = int(xmax * width)
-            ymin = int(ymin * height)
-            ymax = int(ymax * height)
-
-            # Get class name from class ID
-            class_name = class_names.get(class_ids[i], 'Unknown')
-
-            # Print class name and confidence score
-            print(f'Detected Class: {class_name}, Confidence: {scores[i] * 100:.2f}%')
-
-            # Check if the detected object is a laptop or a cell phone
-            if class_ids[i] == 64:  # Class ID for 'laptop'
-                found_items.append('Laptop')
-            elif class_ids[i] == 67:  # Class ID for 'cell phone'
-                found_items.append('Cell Phone')
-
-            # Draw bounding box
-            cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (255, 0, 0), 2)
-            cv2.putText(image, f'{class_name} ({scores[i]:.2f})',
-                        (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
-
-    # Print found items
-    if found_items:
-        print("Found:", ", ".join(set(found_items)))  # Print unique found items
-    else:
-        print("None")
-
-    return image
-
-# Main function to execute the detection
-def main(image_path):
-    image = load_image(image_path)
-    detector_output = detect_objects(image)
-
-    # Extract data from detector output
-    boxes = detector_output["detection_boxes"].numpy()[0]  # Bounding boxes
-    class_ids = detector_output["detection_classes"].numpy()[0].astype(int)  # Class IDs
-    scores = detector_output["detection_scores"].numpy()[0]  # Confidence scores
-
-    # Draw bounding boxes on the image
-    image_with_boxes = draw_boxes(image, boxes, class_ids, scores)
-
-    # Display the image
-    cv2.imshow('Object Detection', cv2.cvtColor(image_with_boxes, cv2.COLOR_RGB2BGR))
-    cv2.waitKey(0)
-    cv2.destroyAllWindows()
-
-# Call the main function with the path to your image
-image_path = "test-image2.jpg"  # Change this to your image path
-main(image_path)
+    h, w, _ = image.shape
+    detected_items = set()
+
+    for box, class_id, score in zip(boxes, class_ids, scores):
+        if score >= threshold:
+            # Rescale bounding box coordinates
+            ymin, xmin, ymax, xmax = [int(val * dim) for val, dim in zip(box, [h, w, h, w])]
+
+            # Retrieve class name
+            class_name = class_names.get(str(class_id), 'Unknown')
+            print(f"Detected: {class_name}, Confidence: {score:.2%}")
+
+            # Draw bounding box and label on image
+            cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (255, 0, 0), 2)
+            cv2.putText(image, f"{class_name} ({score:.2f})", (xmin, ymin - 10),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
+
+            # Track detected specific items
+            if class_id in {64, 67}:  # 64: 'Laptop', 67: 'Cell Phone'
+                detected_items.add(class_name)
+
+    # Log the detected items if any
+    print("Found:", ", ".join(detected_items) if detected_items else "None")
+    return image
+
+# Main function to execute the detection
+def main(image_path, threshold=0.5):
+    model = load_detector_model()
+    image = load_image(image_path)
+    detector_output = detect_objects(model, image)
+
+    # Extract bounding boxes, classes, and scores
+    boxes = detector_output["detection_boxes"].numpy()[0]
+    class_ids = detector_output["detection_classes"].numpy()[0].astype(int)
+    scores = detector_output["detection_scores"].numpy()[0]
+
+    # Draw boxes on the image
+    image_with_boxes = draw_boxes(image, boxes, class_ids, scores, threshold)
+
+    # Display the image if environment allows
+    if "DISPLAY" in os.environ:  # Check if display is available (e.g., on local machines)
+        cv2.imshow("Object Detection", cv2.cvtColor(image_with_boxes, cv2.COLOR_RGB2BGR))
+        cv2.waitKey(0)
+        cv2.destroyAllWindows()
+
+# Run the main function
+if __name__ == "__main__":
+    image_path = "test-image2.jpg"  # Change this to your image path
+    main(image_path)
```
