diff --git a/MedSAM/MedSAMLite/MedSAMLite.py b/MedSAM/MedSAMLite/MedSAMLite.py index c8aecbd..24b6f65 100644 --- a/MedSAM/MedSAMLite/MedSAMLite.py +++ b/MedSAM/MedSAMLite/MedSAMLite.py @@ -32,6 +32,7 @@ try: from numpysocket import NumpySocket + import psutil except: pass # no installation anymore, shorter plugin load @@ -213,6 +214,29 @@ def setup(self) -> None: self.layout.addWidget(uiWidget) self.ui = slicer.util.childWidgetVariables(uiWidget) + ############################################################################ + # Model Selection + if hasattr(self.ui, 'ctkPathModel'): + self.model_path_widget = self.ui.ctkPathModel + # self.ui.clbtnOperation.layout().addWidget(self.ui.lblModelSelection, 0, 0) + # self.ui.clbtnOperation.layout().addWidget(self.ui.ctkPathModel, 0, 1) + else: + import ctk + from PythonQt.QtGui import QLabel + + path_instruction = QLabel('MedSAM Model:') + + self.model_path_widget = ctk.ctkPathLineEdit() + self.model_path_widget.filters = ctk.ctkPathLineEdit.Files + self.model_path_widget.nameFilters = ['*.pth'] + + self.ui.clbtnOperation.layout().addWidget(path_instruction, 0, 0) + self.ui.clbtnOperation.layout().addWidget(self.model_path_widget, 0, 1) + + self.model_path_widget.currentPath = os.path.join(self.logic.server_dir, 'medsam_lite.pth') + self.logic.new_model_loaded = True + ############################################################################ + ############################################################################ # Segmentation Module @@ -222,7 +246,7 @@ def setup(self) -> None: self.selectParameterNode() self.editor.setMRMLScene(slicer.mrmlScene) # print(self.ui.clbtnOperation.layout().__dict__) - self.ui.clbtnOperation.layout().addWidget(self.editor) + self.ui.clbtnOperation.layout().addWidget(self.editor, 3, 0, 1, 2) # self.layout.addWidget(self.editor) # self.editor.currentSegmentIDChanged.connect(print) ############################################################################ @@ -251,6 +275,8 @@ def setup(self) -> None: self.ui.pbAttach.connect('clicked(bool)', lambda: self._createAndAttachROI()) self.ui.pbTwoDim.connect('clicked(bool)', lambda: self.makeROI2D()) + self.model_path_widget.connect('currentPathChanged(const QString&)', lambda: setattr(self.logic, 'new_model_loaded', True)) + # Make sure parameter node is initialized (needed for module reload) self.initializeParameterNode() @@ -419,6 +445,7 @@ class MedSAMLiteLogic(ScriptedLoadableModuleLogic): progressbar = None server_dir = None widget = None + new_model_loaded = True def __init__(self) -> None: """ @@ -540,15 +567,23 @@ def upgrade(self, download, event): github_base = 'https://raw.githubusercontent.com/bowang-lab/MedSAMSlicer/v%.2f/'%latest_version server_url = github_base + 'server/server.py' module_url = github_base + 'MedSAM/MedSAMLite/MedSAMLite.py' + ui_url = github_base + 'MedSAM/MedSAMLite/Resources/UI/MedSAMLite.ui' + + server_file_path = os.path.join(self.server_dir, 'server.py') + module_file_path = __file__ + ui_file_path = os.path.join(os.path.dirname(__file__), 'Resources/UI/MedSAMLite.ui') self.progressbar.setLabelText('Downloading updates...') server_req = requests.get(server_url) module_req = requests.get(module_url) + module_req = requests.get(ui_url) - with open(os.path.join(self.server_dir, 'server.py'), 'w') as server_file: + with open(server_file_path, 'w') as server_file: server_file.write(server_req.text) - with open(__file__, 'w') as module_file: + with open(module_file_path, 'w') as module_file: module_file.write(module_req.text) + with open(ui_file_path, 'w') as ui_file: + ui_file.write(ui_url.text) self.progressbar.setLabelText('Upgraded successfully, please restart Slicer.') else: @@ -560,6 +595,29 @@ def upgrade(self, download, event): event.set() + def check_server(self, serverUrl, max_retries, event = None): + retry_cnt = 0 + while True: + try: + retry_cnt += 1 + response = requests.post(f'{serverUrl}/getServerState') + server_ready = json.loads(response.json()) + if server_ready['ready']: + if event: + event.set() + return True + elif retry_cnt == max_retries: + if event: + event.set() + return False + except Exception as e: + if retry_cnt == max_retries: + if event: + event.set() + return False + time.sleep(1) + + def run_on_background(self, target, args, title): self.progressbar = slicer.util.createProgressDialog(autoClose=False) @@ -576,27 +634,27 @@ def run_on_background(self, target, args, title): self.progressbar.close() - def run_server(self): - print('Running server...') - - # buggy_file_path = os.getcwd() + '/lib/Python/lib/python3.9/site-packages/typing_extensions.py' - - # with open(buggy_file_path, 'r') as file: - # lines = file.readlines() - - # # Update the value in line 173 - # buggy_line_num = 173 - # new_value = '\t\t\tt, (typing._GenericAlias, _types.GenericAlias, _types.UnionType)' + def run_server(self, serverUrl, numpyServerAddress): + print('Terminating possible server duplicates...') + # Terminate image transferrer + try: + with NumpySocket() as s: + s.connect(numpyServerAddress) + s.sendall(np.array([])) + except: + pass - # if 1 <= buggy_line_num <= len(lines): - # lines[buggy_line_num - 1] = f"{new_value}\n" + # Terminate whole server + server_port = int(serverUrl.split(':')[-1]) + try: + server_process = list(filter(lambda proc: proc.laddr.port == server_port and psutil.Process(proc.pid).name() == 'python-real', psutil.net_connections()))[0] + psutil.Process(server_process.pid).terminate() + except: + pass - # # Write the updated content back to the file - # with open(buggy_file_path, 'w') as file: - # file.writelines(lines) - + print('Running server...') - self.server_process = subprocess.Popen(['PythonSlicer', os.path.join(self.server_dir, 'server.py')])#, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE, start_new_session=True) + self.server_process = subprocess.Popen(['PythonSlicer', os.path.join(self.server_dir, 'server.py'), self.widget.model_path_widget.currentPath])#, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE, start_new_session=True) def cleanup(): timeout_sec = 5 @@ -609,9 +667,7 @@ def cleanup(): self.server_process.kill() # atexit.register(cleanup) - self.server_ready = True - - time.sleep(4) #Change + self.run_on_background(self.check_server, (serverUrl, 10), 'Backend is loading...') def progressCheck(self, serverUrl='http://127.0.0.1:5555'): @@ -628,13 +684,12 @@ def progressCheck(self, serverUrl='http://127.0.0.1:5555'): def captureImage(self): ######## Set your image path here self.volume_node = slicer.util.getNodesByClass('vtkMRMLScalarVolumeNode')[0] - self.img_path = self.volume_node.GetStorageNode().GetFullNameFromFileName() - self.img_sitk = sitk.ReadImage(self.img_path) self.image_data = slicer.util.arrayFromVolume(self.volume_node) ################ Only one node? def sendImage(self, serverUrl='http://127.0.0.1:5555', numpyServerAddress=("127.0.0.1", 5556)): - if not self.server_ready: - self.run_server() + if self.new_model_loaded or not self.check_server(serverUrl, max_retries=1, event=None): + self.run_server(serverUrl, numpyServerAddress) + self.new_model_loaded = False print('sending setImage request...') response = requests.post(f'{serverUrl}/setImage', json={"wmin": -160, "wmax": 240}) # wmin, wmax as input? print('Response from setImage:', response.text) @@ -678,14 +733,8 @@ def inferSegmentation(self, serverUrl='http://127.0.0.1:5555'): return seg_result def showSegmentation(self, segmentation_mask): - segmentation_res_file = os.path.dirname(self.img_path) + '/lite_seg_' + os.path.basename(self.img_path) - seg_sitk = sitk.GetImageFromArray(segmentation_mask) - seg_sitk.CopyInformation(self.img_sitk) - sitk.WriteImage(seg_sitk, segmentation_res_file) ########## Set your segmentation output here - loaded_seg_file = slicer.util.loadSegmentation(segmentation_res_file) - - segment_volume = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLLabelMapVolumeNode") - slicer.modules.segmentations.logic().ExportAllSegmentsToLabelmapNode(loaded_seg_file, segment_volume, slicer.vtkSegmentation.EXTENT_REFERENCE_GEOMETRY) + segment_volume = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLLabelMapVolumeNode", 'segment_'+str(int(time.time()))) + slicer.util.updateVolumeFromArray(segment_volume, segmentation_mask) current_seg_group = self.widget.editor.segmentationNode() if current_seg_group is None: @@ -696,14 +745,17 @@ def showSegmentation(self, segmentation_mask): try: check_if_node_is_removed = slicer.util.getNode(current_seg_group.GetID()) # if scene is closed and reopend, this line will raise an error - slicer.modules.segmentations.logic().ImportLabelmapToSegmentationNode(segment_volume, current_seg_group) except: self.segment_res_group = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode") self.segment_res_group.SetReferenceImageGeometryParameterFromVolumeNode(self.volume_node) - slicer.modules.segmentations.logic().ImportLabelmapToSegmentationNode(segment_volume, self.segment_res_group) + current_seg_group = self.segment_res_group + + + current_seg_group.SetReferenceImageGeometryParameterFromVolumeNode(self.volume_node) + slicer.modules.segmentations.logic().ImportLabelmapToSegmentationNode(segment_volume, current_seg_group) + slicer.util.updateSegmentBinaryLabelmapFromArray(segmentation_mask, current_seg_group, segment_volume.GetName(), self.volume_node) slicer.mrmlScene.RemoveNode(segment_volume) - slicer.mrmlScene.RemoveNode(loaded_seg_file) def applySegmentation(self, serverUrl='http://127.0.0.1:5555'): segmentation_mask = self.inferSegmentation(serverUrl) diff --git a/MedSAM/MedSAMLite/Resources/UI/MedSAMLite.ui b/MedSAM/MedSAMLite/Resources/UI/MedSAMLite.ui index 631acac..64f366b 100644 --- a/MedSAM/MedSAMLite/Resources/UI/MedSAMLite.ui +++ b/MedSAM/MedSAMLite/Resources/UI/MedSAMLite.ui @@ -98,20 +98,39 @@ Start Segmentation - + Send Image - + Segmentation + + + + ctkPathLineEdit::Executable|ctkPathLineEdit::Files|ctkPathLineEdit::NoDot|ctkPathLineEdit::NoDotDot|ctkPathLineEdit::Readable + + + + *.pth + + + + + + + + MedSAM Model: + + + @@ -124,6 +143,11 @@
ctkCollapsibleButton.h
1 + + ctkPathLineEdit + QWidget +
ctkPathLineEdit.h
+
qMRMLWidget QWidget diff --git a/server/server.py b/server/server.py index d97093b..a5972ef 100644 --- a/server/server.py +++ b/server/server.py @@ -63,8 +63,7 @@ def medsam_inference(medsam_model, img_embed, box_1024, height, width): # settings and app states SAM_MODEL_TYPE = "vit_b" -PARENT_DIR = os.path.dirname(os.path.abspath(__file__)) -MedSAM_CKPT_PATH = os.path.join(PARENT_DIR , "medsam_lite.pth") +MedSAM_CKPT_PATH = sys.argv[1] MEDSAM_IMG_INPUT_SIZE = 1024 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -293,6 +292,10 @@ def set_image(params: ImageParams, background_tasks: BackgroundTasks): def get_progress(): return json.dumps({'layers': image.shape[0], 'generated_embeds': len(embeddings)}) +@app.post("/getServerState") +def get_server_state(): + return json.dumps({'ready': True}) + class InferenceParams(BaseModel): slice_idx: int diff --git a/server_essentials.zip b/server_essentials.zip index 6cb3b8e..82fab91 100644 Binary files a/server_essentials.zip and b/server_essentials.zip differ