Skip to content

Commit

Permalink
Update object_detection.py (#93)
Browse files Browse the repository at this point in the history
Co-authored-by: Tanisha Lalwani <[email protected]>
  • Loading branch information
Charu19awasthi and tanishaness authored Nov 4, 2024
1 parent 2bf2f72 commit 2f757a7
Showing 1 changed file with 65 additions and 52 deletions.
117 changes: 65 additions & 52 deletions Backend/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,78 +4,91 @@
import tensorflow_hub as hub
import json

# Function to load the SSD MobileNet V2 model
def load_detector_model():
    """Load the SSD MobileNet V2 (FPNLite 320x320) detector through TensorFlow Hub.

    Returns:
        The object returned by ``hub.load`` for the Kaggle-hosted model handle.
    """
    model_handle = "https://www.kaggle.com/models/tensorflow/ssd-mobilenet-v2/TensorFlow2/fpnlite-320x320/1"
    detector = hub.load(model_handle)
    return detector
# Load the SSD MobileNet V2 model from TensorFlow Hub
def load_model(model_url="https://tfhub.dev/tensorflow/ssd_mobilenet_v2/2"):
    """Load an object-detection model from TensorFlow Hub.

    Args:
        model_url: Hub handle to load. Defaults to the SSD MobileNet V2
            handle this script has always used, so existing callers are
            unaffected.

    Returns:
        The loaded model, or ``None`` if loading failed (the error is
        printed rather than raised, so callers must check for ``None``).
    """
    try:
        # Keep the try body to the single call that can fail.
        model = hub.load(model_url)
    except Exception as e:
        # Deliberately broad: loading can fail for network, cache or format
        # reasons, and this function's contract is "report and return None".
        print("Error loading model:", e)
        return None
    print("Model loaded successfully.")
    return model

# Load class names from an external file
# Parsed once at import time into the module-level ``class_names`` mapping.
# JSON files are UTF-8, so decode explicitly instead of relying on the
# platform's locale encoding.
with open("class_names.json", encoding="utf-8") as f:
    class_names = json.load(f)
def load_class_names(filename="class_names.json"):
    """Read the class-id to class-name mapping from a JSON file.

    Args:
        filename: Path to the JSON file to parse.

    Returns:
        The parsed JSON content (whatever ``json.load`` produces for the file).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; without an explicit encoding, open()
    # falls back to the locale encoding and can mis-decode non-ASCII names.
    with open(filename, encoding="utf-8") as f:
        return json.load(f)

# Function to perform object detection
# Function to perform object detection on an image tensor
def detect_objects(model, image):
    """Run the detector on a single image.

    The image is scaled so that its longer side becomes 320 pixels (aspect
    ratio preserved), converted to a ``uint8`` tensor, given a leading batch
    axis, and passed to ``model``.

    Args:
        model: Callable detector; invoked with the batched input tensor.
        image: Array whose first two dimensions are height and width
            (presumably the RGB array produced by ``load_image`` - confirm).

    Returns:
        Whatever ``model`` returns for the batched input tensor.

    Raises:
        ValueError: If the image has zero height or width.
    """
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        # Fail with a clear message instead of a ZeroDivisionError / cv2 error.
        raise ValueError("Cannot run detection on an empty image.")

    scale_factor = 320 / max(h, w)
    # cv2.resize expects (width, height). Clamp to 1 px so a very elongated
    # image does not truncate its short side to zero.
    target_size = (max(1, int(w * scale_factor)), max(1, int(h * scale_factor)))
    image_resized = cv2.resize(image, target_size)

    # Convert once, then add the batch axis the model is called with.
    input_tensor = tf.convert_to_tensor(image_resized, dtype=tf.uint8)[tf.newaxis, ...]

    return model(input_tensor)

# Load and preprocess the image
def load_image(image_path):
    """Read an image file and return it in RGB channel order.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The image converted from BGR to RGB, or ``None`` when the file could
        not be read.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals an unreadable/missing file by returning None;
        # passing that to cvtColor would raise an opaque OpenCV error.
        return None
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Draw boxes with information overlay
def draw_boxes(image, boxes, class_ids, scores, class_names, threshold=0.5):
    """Draw a labelled rectangle on ``image`` for every confident detection.

    Args:
        image: Array of shape (height, width, channels); it is drawn on in place.
        boxes: Per-detection ``[ymin, xmin, ymax, xmax]`` values. They are
            multiplied by the image height/width, i.e. treated as normalized.
        class_ids: Per-detection class ids; each is looked up in
            ``class_names`` by its string form.
        scores: Per-detection confidence scores, compared against ``threshold``.
        class_names: Mapping from stringified class id to display name;
            ids that are missing are labelled "Unknown".
        threshold: Minimum score for a detection to be drawn.

    Returns:
        The same ``image`` object, with boxes and labels drawn on it.
    """
    h, w, _ = image.shape
    detected_items = []

    for box, class_id, score in zip(boxes, class_ids, scores):
        if score < threshold:
            continue

        # Scale the box to pixel coordinates.
        ymin, xmin, ymax, xmax = [int(val * dim) for val, dim in zip(box, (h, w, h, w))]

        class_name = class_names.get(str(class_id), "Unknown")
        confidence = score * 100

        # Draw bounding box and label
        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (255, 0, 0), 2)
        label = f"{class_name}: {confidence:.2f}%"
        # Keep the text baseline inside the image for boxes touching the top edge.
        label_y = max(ymin - 10, 10)
        cv2.putText(image, label, (xmin, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
        detected_items.append(class_name)

    # De-duplicate while keeping first-seen order so the log line is stable
    # between runs (joining a set prints in arbitrary order).
    unique_items = list(dict.fromkeys(detected_items))
    print("Detected items:", ", ".join(unique_items) if unique_items else "None")
    return image

# Main function to execute the detection
def load_image(image_path):
    """Read an image file and return it as an RGB array, or ``None`` if unreadable."""
    image = cv2.imread(image_path)
    # cv2.imread returns None (it does not raise) when the file cannot be read.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) if image is not None else None


def main(image_path, class_names_file="class_names.json", threshold=0.5, save_output=False):
    """Detect objects in one image, annotate it, and show/save the result.

    Args:
        image_path: Path of the image to analyse.
        class_names_file: JSON file mapping class ids to names.
        threshold: Minimum detection score for a box to be drawn.
        save_output: When true, write the annotated image to
            ``output_detected.jpg`` in the working directory.

    Returns:
        None. Failures to load the image or the model are reported on stdout
        and end the run early.
    """
    # Validate the cheap inputs first so a bad path fails before the
    # (slow, possibly network-bound) model load.
    image = load_image(image_path)
    if image is None:
        print("Error: Could not load image.")
        return

    class_names = load_class_names(class_names_file)

    # load_model() reports its own error and returns None on failure; calling
    # None as a model would otherwise raise a confusing TypeError.
    model = load_model()
    if model is None:
        print("Error: Could not load model.")
        return

    # Detect objects
    detection = detect_objects(model, image)

    # Extract detection data (index 0 selects the single image in the batch)
    boxes = detection["detection_boxes"].numpy()[0]
    class_ids = detection["detection_classes"].numpy()[0].astype(int)
    scores = detection["detection_scores"].numpy()[0]

    # Draw boxes on image
    image_with_boxes = draw_boxes(image, boxes, class_ids, scores, class_names, threshold)

    # OpenCV's imwrite/imshow expect BGR; convert once and reuse.
    output_bgr = cv2.cvtColor(image_with_boxes, cv2.COLOR_RGB2BGR)

    # Save before showing so the file is written even if no display exists.
    if save_output:
        cv2.imwrite("output_detected.jpg", output_bgr)

    try:
        cv2.imshow("Object Detection", output_bgr)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    except cv2.error as e:
        # Headless OpenCV builds / machines without a GUI cannot open windows.
        print("Display unavailable, skipping preview:", e)

# Run the main function
# Run detection with threshold and save option
if __name__ == "__main__":
    # Run detection exactly once on the sample image and save the annotated copy.
    image_path = "test-image2.jpg"  # Adjust image path as needed
    main(image_path, threshold=0.5, save_output=True)

0 comments on commit 2f757a7

Please sign in to comment.