Merge pull request #19 from rcland12/jetson_develop
Jetson develop
rcland12 authored Mar 20, 2024
2 parents fdc3117 + 97ea394 commit 47958e1
Showing 11 changed files with 152 additions and 155 deletions.
4 changes: 2 additions & 2 deletions app/Dockerfile
@@ -69,5 +69,5 @@ RUN python3 -m pip install --upgrade --no-cache-dir pip \
&& rm -rf ${BUILDDIR}/tritonserver/

WORKDIR ${APPDIR}
COPY main.py utilities.py ./
CMD ["python3", "main.py"]
COPY main.py utilities.py stream.sh entrypoint.sh ./
CMD ["/bin/bash", "entrypoint.sh"]
9 changes: 9 additions & 0 deletions app/entrypoint.sh
@@ -0,0 +1,9 @@
#!/bin/bash
if [ "${OBJECT_DETECTION}" == "True" ]; then
python3 main.py
elif [ "${OBJECT_DETECTION}" == "False" ]; then
/bin/bash stream.sh
else
echo "Invalid input for OBJECT_DETECTION. Expecting True or False; received ${OBJECT_DETECTION}."
exit 120
fi
141 changes: 65 additions & 76 deletions app/main.py
@@ -1,10 +1,10 @@
import os
import cv2
import torch
import numpy
import typing
import imutils
import subprocess
import numpy as np

from nanocamera import Camera
from dotenv import load_dotenv
@@ -120,8 +120,6 @@ class ObjectDetection():
def __init__(
self,
model_name,
camera_width,
camera_height,
triton_url
):

@@ -130,12 +128,9 @@ def __init__(
except ConnectionError as e:
raise f"Failed to connect to Triton: {e}"

self.frame_dims = [camera_width, camera_height]

def __call__(self, frame):
predictions = self.model(
frame,
numpy.array(self.frame_dims, dtype='int16')
frame
).tolist()

bboxes = [item[:4] for item in predictions]
@@ -146,38 +141,45 @@ def __call__(self, frame):


class Annotator():
def __init__(self, classes, width=1280, height=720):
self.classes = classes
def __init__(self, classes, width=1280, height=720, santa_hat_plugin_bool=False):
self.width = width
self.height = height
self.colors = list(numpy.random.rand(len(self.classes), 3) * 255)
self.classes = classes
self.colors = list(np.random.rand(len(self.classes), 3) * 255)
self.santa_hat = cv2.imread("images/santa_hat.png")
self.santa_hat_mask = cv2.imread("images/santa_hat_mask.png")
self.santa_hat_plugin_bool = santa_hat_plugin_bool

def __call__(self, frame, bboxes, confs, indexes):
for i in range(len(bboxes)):
xmin, ymin, xmax, ymax = [int(j) for j in bboxes[i]]
color = self.colors[indexes[i]]
frame = cv2.rectangle(
img=frame,
pt1=(xmin, ymin),
pt2=(xmax, ymax),
color=color,
thickness=2
)

frame = cv2.putText(
img=frame,
text=f'{self.classes[indexes[i]]} ({str(confs[i])})',
org=(xmin, ymin),
fontFace=cv2.FONT_HERSHEY_PLAIN,
fontScale=0.75,
color=color,
thickness=1,
lineType=cv2.LINE_AA
)
if not self.santa_hat_plugin_bool:
for i in range(len(bboxes)):
xmin, ymin, xmax, ymax = [int(j) for j in bboxes[i]]
color = self.colors[indexes[i]]
frame = cv2.rectangle(
img=frame,
pt1=(xmin, ymin),
pt2=(xmax, ymax),
color=color,
thickness=2
)

frame = cv2.putText(
img=frame,
text=f'{self.classes[indexes[i]]} ({str(confs[i])})',
org=(xmin, ymin - 5),
fontFace=cv2.FONT_HERSHEY_PLAIN,
fontScale=0.75,
color=color,
thickness=1,
lineType=cv2.LINE_AA
)

return frame

return frame
else:
# For santa hat plugin, turn Normalize to True in nms function
max_index = max(range(len(confs)), key=confs.__getitem__)
return self._overlay_obj(frame, bboxes[max_index].copy())

def _overlay_obj(self, frame, bbox):
bbox = [int(i * scalar) for i, scalar in zip(bbox, [self.width, self.height, self.width, self.height])]
@@ -189,7 +191,7 @@ def _overlay_obj(self, frame, bbox):
hat_height, hat_width = santa_hat.shape[0], santa_hat.shape[1]

mask_boolean = santa_hat_mask[:, :, 0] == 0
mask_rgb_boolean = numpy.stack([mask_boolean, mask_boolean, mask_boolean], axis=2)
mask_rgb_boolean = np.stack([mask_boolean, mask_boolean, mask_boolean], axis=2)

if x >= 0 and y >= 0:
h = hat_height - max(0, y+hat_height-self.height)
@@ -212,16 +214,10 @@ def _overlay_obj(self, frame, bbox):
frame[0:0+h, x:x+w, :] = frame[0:0+h, x:x+w, :] * ~mask_rgb_boolean[hat_height-h:hat_height, 0:w, :] + (santa_hat * mask_rgb_boolean)[hat_height-h:hat_height, 0:w, :]

return frame

def santa_hat_plugin(self, frame, bboxes, confs):
# For santa hat plugin, turn Normalize to True in nms function
max_index = max(range(len(confs)), key=confs.__getitem__)
return self._overlay_obj(frame, bboxes[max_index].copy())



def main(
object_detection,
triton_url,
model_name,
stream_ip,
@@ -231,7 +227,8 @@ def main(
camera_index,
camera_width,
camera_height,
camera_fps
camera_fps,
santa_hat_plugin
):

rtmp_url = "rtmp://{}:{}/{}/{}".format(
@@ -267,41 +264,33 @@ def main(

p = subprocess.Popen(command, stdin=subprocess.PIPE)

if object_detection:
model = ObjectDetection(
model_name=model_name,
camera_width=camera_width,
camera_height=camera_height,
triton_url=triton_url
)

annotator = Annotator(
model.model.classes,
camera_width,
camera_height
)

period = 10
tracking_index = 0

while camera.isReady():
frame = camera.read()

if tracking_index % period == 0:
bboxes, confs, indexes = model(frame)
tracking_index = 0

if bboxes:
frame = annotator(frame, bboxes, confs, indexes)
# frame = annotator.santa_hat_plugin(frame, bboxes, confs)
tracking_index += 1
model = ObjectDetection(
model_name=model_name,
triton_url=triton_url
)

annotator = Annotator(
model.model.classes,
camera_width,
camera_height,
santa_hat_plugin
)

p.stdin.write(frame.tobytes())
period = 10
tracking_index = 0

while camera.isReady():
frame = camera.read()

if tracking_index % period == 0:
bboxes, confs, indexes = model(frame)
tracking_index = 0

if bboxes:
frame = annotator(frame, bboxes, confs, indexes)
tracking_index += 1

else:
while camera.isReady():
frame = camera.read()
p.stdin.write(frame.tobytes())
p.stdin.write(frame.tobytes())

camera.release()
del camera
@@ -310,7 +299,6 @@ def main(
if __name__ == "__main__":
load_dotenv()
parser = EnvArgumentParser()
parser.add_arg("OBJECT_DETECTION", default=True, type=bool)
parser.add_arg("TRITON_URL", default="grpc://localhost:8001", type=str)
parser.add_arg("MODEL_NAME", default="yolov5s", type=str)
parser.add_arg("STREAM_IP", default="127.0.0.1", type=str)
@@ -321,10 +309,10 @@ def main(
parser.add_arg("CAMERA_WIDTH", default=1280, type=int)
parser.add_arg("CAMERA_HEIGHT", default=720, type=int)
parser.add_arg("CAMERA_FPS", default=30, type=int)
parser.add_arg("SANTA_HAT_PLUGIN", default=False, type=bool)
args = parser.parse_args()

main(
args.OBJECT_DETECTION,
args.TRITON_URL,
args.MODEL_NAME,
args.STREAM_IP,
@@ -334,5 +322,6 @@ def main(
args.CAMERA_INDEX,
args.CAMERA_WIDTH,
args.CAMERA_HEIGHT,
args.CAMERA_FPS
args.CAMERA_FPS,
args.SANTA_HAT_PLUGIN
)
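
For reference, the santa-hat branch added to Annotator.__call__ keys on the single highest-confidence detection, and _overlay_obj rescales that normalized box to pixel coordinates. A minimal sketch of those two steps with made-up values; the confidences, boxes, and 1280x720 frame size below are placeholders, not repo data:

# Sketch only: same expressions as Annotator.__call__ / _overlay_obj, placeholder data.
confs = [0.42, 0.91, 0.77]
bboxes = [
    [0.10, 0.20, 0.30, 0.40],
    [0.25, 0.10, 0.50, 0.30],   # highest-confidence detection
    [0.60, 0.55, 0.80, 0.90],
]

# Index of the largest confidence, as in the santa-hat branch above.
max_index = max(range(len(confs)), key=confs.__getitem__)        # -> 1

# Normalized [xmin, ymin, xmax, ymax] scaled by [width, height, width, height],
# as in _overlay_obj.
width, height = 1280, 720
scaled = [int(v * s) for v, s in zip(bboxes[max_index], [width, height, width, height])]
print(scaled)                                                    # [320, 72, 640, 216]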
8 changes: 8 additions & 0 deletions app/stream.sh
@@ -0,0 +1,8 @@
#!/bin/bash
RTMP_URI=location=rtmp://"$STREAM_IP":"$STREAM_PORT"/"$STREAM_APPLICATION"/"$STREAM_KEY"" live=true"
gst-launch-1.0 nvarguscamerasrc sensor-id=$CAMERA_INDEX ! \
'video/x-raw(memory:NVMM)', width=$CAMERA_WIDTH, height=$CAMERA_HEIGHT, framerate=$CAMERA_FPS/1, format=NV12 ! \
nvv4l2h264enc ! \
h264parse ! \
flvmux ! \
rtmpsink "$RTMP_URI"
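
Both this GStreamer fallback and the detection path in main.py publish to the same RTMP endpoint assembled from the STREAM_* variables. A rough sketch of how main.py builds that URL; only the 127.0.0.1 default comes from the parser, and the port, application, and key values are purely illustrative:

# Placeholder values; only STREAM_IP's default (127.0.0.1) appears in main.py.
stream_ip = "127.0.0.1"
stream_port = 1935            # assumed here, not taken from the repo
stream_application = "live"   # hypothetical
stream_key = "stream"         # hypothetical

rtmp_url = "rtmp://{}:{}/{}/{}".format(stream_ip, stream_port, stream_application, stream_key)
print(rtmp_url)               # rtmp://127.0.0.1:1935/live/stream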
5 changes: 3 additions & 2 deletions app/utilities.py
@@ -18,11 +18,12 @@ def add_arg(self, variable, default=None, type=str):
if env is None:
value = default
else:
value = self.cast_type(env, type)
value = self._cast_type(env, type)

self.dict[variable] = value

def cast_type(self, arg, d_type):
@staticmethod
def _cast_type(arg, d_type):
if d_type == list or d_type == tuple or d_type == bool:
try:
cast_value = literal_eval(arg)
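
For context, main.py drives EnvArgumentParser roughly as sketched below. This assumes parse_args() exposes each registered variable as an attribute (which is how main.py reads them) and that the class is importable from app/utilities.py; the environment value is made up:

import os
from utilities import EnvArgumentParser   # assumed import path within app/

os.environ["CAMERA_WIDTH"] = "1920"       # example value; add_arg casts it to int

parser = EnvArgumentParser()
parser.add_arg("CAMERA_WIDTH", default=1280, type=int)
parser.add_arg("SANTA_HAT_PLUGIN", default=False, type=bool)
args = parser.parse_args()

print(args.CAMERA_WIDTH)      # 1920  (read from the environment, cast via _cast_type)
print(args.SANTA_HAT_PLUGIN)  # False (variable unset, so the default is used)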
5 changes: 1 addition & 4 deletions triton/object_detection/config.pbtxt
@@ -25,10 +25,7 @@ model_warmup [{
key: "images"
value: {
data_type: TYPE_FP16
dims: 1
dims: 3
dims: 640
dims: 640
dims: [ 1, 3, 640, 640 ]
input_data_file: "images"
}
}
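
The collapsed dims entry keeps the same warmup shape as before: one 3x640x640 FP16 image in NCHW layout. A client-side array with the matching shape and dtype would look like this (sketch only; the variable name simply mirrors the "images" key above):

import numpy as np

# Warmup-shaped input matching dims [ 1, 3, 640, 640 ] and TYPE_FP16.
warmup_images = np.zeros((1, 3, 640, 640), dtype=np.float16)
print(warmup_images.shape)    # (1, 3, 640, 640)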