diff --git a/spot_rl_experiments/spot_rl/utils/segmentation_service.py b/spot_rl_experiments/spot_rl/utils/segmentation_service.py
index 58b913a3..ca168435 100644
--- a/spot_rl_experiments/spot_rl/utils/segmentation_service.py
+++ b/spot_rl_experiments/spot_rl/utils/segmentation_service.py
@@ -24,7 +24,7 @@ def load_model(model_name="owlvit", device="cpu"):
             "google/owlv2-base-patch16-ensemble"
         ).to(device)
         processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
-        return model, processor
+        # return model, processor
     if model_name == "sam":
         print("Loading SAM")
         # TODO: load sam weights from config
@@ -32,7 +32,10 @@ def load_model(model_name="owlvit", device="cpu"):
         sam = SamPredictor(build_sam(checkpoint=checkpoint_path).to(device))
 
 
-def detect(img, text_queries, score_threshold, device, model=None, processor=None):
+def detect(img, text_queries, score_threshold, device):
+    global model
+    global processor
+
     if model is None or processor is None:
         load_model("owlvit", device)
 
@@ -103,6 +106,18 @@ def connect_socket(port):
     return socket
 
 
+def detect_and_segment_with_socket(rgb_image, port=21001):
+    """Send an image to the detection + segmentation server; return its reply.
+
+    The connection is cached in the module-level ``socket``, so ``port`` is
+    only used when no connection has been opened yet.
+    """
+    global socket
+    socket = connect_socket(port) if socket is None else socket
+    socket.send_pyobj(rgb_image)
+    return socket.recv_pyobj()
+
+
 def segment_with_socket(rgb_image, bbox, port=21001):
     global socket
     socket = connect_socket(port) if socket is None else socket
@@ -121,12 +136,31 @@ def segment_with_socket(rgb_image, bbox, port=21001):
     print(f"Segmentation Server Listening on port {port}")
 
     while True:
-        """A service for running segmentation service, send request using zmq socket"""
-        img, bbox = socket.recv_pyobj()
-        print("Recieved img for Segmentation")
-        masks = segment(img, np.array([bbox]), img.shape[:2], device)
-        mask = masks[0, 0].cpu().numpy()  # hxw, bool
-        socket.send_pyobj(mask)
+        # Detection + segmentation service; requests arrive on the zmq socket.
+        # NOTE(review): recv_pyobj unpickles whatever it receives, so this port
+        # must only be reachable by trusted clients.
+        request = socket.recv_pyobj()
+        if isinstance(request, (tuple, list)):
+            # Legacy (img, bbox) request accepted by the previous server loop:
+            # keep replying with a single mask so existing callers still work.
+            img, bbox = request
+            print("Received img for Segmentation")
+            masks = segment(img, np.array([bbox]), img.shape[:2], device)
+            socket.send_pyobj(masks[0, 0].cpu().numpy())  # hxw, bool
+            continue
+
+        img = request
+        print("Received img for Detection & Segmentation")
+        result_labels = detect(img, "table,chair,bed,couch,bottle,can,cup", 0.25, device)
+
+        masks_list = []
+        for result_label in result_labels:
+            bbox = result_label[0]
+            masks = segment(img, np.array([bbox]), img.shape[:2], device)
+            mask = masks[0, 0].cpu().numpy()  # hxw, bool
+            masks_list.append(mask)
+
+        socket.send_pyobj(masks_list)
 
 # If you want to use detection service, then use the following code to listen to socket
 # def detect_with_socket(img, object_name, thresh=0.01, device="cuda"):