diff --git a/MedSAM/MedSAMLite/MedSAMLite.py b/MedSAM/MedSAMLite/MedSAMLite.py
index c8aecbd..24b6f65 100644
--- a/MedSAM/MedSAMLite/MedSAMLite.py
+++ b/MedSAM/MedSAMLite/MedSAMLite.py
@@ -32,6 +32,7 @@
try:
from numpysocket import NumpySocket
+ import psutil
except:
pass # no installation anymore, shorter plugin load
@@ -213,6 +214,29 @@ def setup(self) -> None:
self.layout.addWidget(uiWidget)
self.ui = slicer.util.childWidgetVariables(uiWidget)
+ ############################################################################
+ # Model Selection
+ if hasattr(self.ui, 'ctkPathModel'):
+ self.model_path_widget = self.ui.ctkPathModel
+ # self.ui.clbtnOperation.layout().addWidget(self.ui.lblModelSelection, 0, 0)
+ # self.ui.clbtnOperation.layout().addWidget(self.ui.ctkPathModel, 0, 1)
+ else:
+ import ctk
+ from PythonQt.QtGui import QLabel
+
+ path_instruction = QLabel('MedSAM Model:')
+
+ self.model_path_widget = ctk.ctkPathLineEdit()
+ self.model_path_widget.filters = ctk.ctkPathLineEdit.Files
+ self.model_path_widget.nameFilters = ['*.pth']
+
+ self.ui.clbtnOperation.layout().addWidget(path_instruction, 0, 0)
+ self.ui.clbtnOperation.layout().addWidget(self.model_path_widget, 0, 1)
+
+ self.model_path_widget.currentPath = os.path.join(self.logic.server_dir, 'medsam_lite.pth')
+ self.logic.new_model_loaded = True
+ ############################################################################
+
############################################################################
# Segmentation Module
@@ -222,7 +246,7 @@ def setup(self) -> None:
self.selectParameterNode()
self.editor.setMRMLScene(slicer.mrmlScene)
# print(self.ui.clbtnOperation.layout().__dict__)
- self.ui.clbtnOperation.layout().addWidget(self.editor)
+ self.ui.clbtnOperation.layout().addWidget(self.editor, 3, 0, 1, 2)
# self.layout.addWidget(self.editor)
# self.editor.currentSegmentIDChanged.connect(print)
############################################################################
@@ -251,6 +275,8 @@ def setup(self) -> None:
self.ui.pbAttach.connect('clicked(bool)', lambda: self._createAndAttachROI())
self.ui.pbTwoDim.connect('clicked(bool)', lambda: self.makeROI2D())
+ self.model_path_widget.connect('currentPathChanged(const QString&)', lambda: setattr(self.logic, 'new_model_loaded', True))
+
# Make sure parameter node is initialized (needed for module reload)
self.initializeParameterNode()
@@ -419,6 +445,7 @@ class MedSAMLiteLogic(ScriptedLoadableModuleLogic):
progressbar = None
server_dir = None
widget = None
+ new_model_loaded = True
def __init__(self) -> None:
"""
@@ -540,15 +567,23 @@ def upgrade(self, download, event):
github_base = 'https://raw.githubusercontent.com/bowang-lab/MedSAMSlicer/v%.2f/'%latest_version
server_url = github_base + 'server/server.py'
module_url = github_base + 'MedSAM/MedSAMLite/MedSAMLite.py'
+ ui_url = github_base + 'MedSAM/MedSAMLite/Resources/UI/MedSAMLite.ui'
+
+ server_file_path = os.path.join(self.server_dir, 'server.py')
+ module_file_path = __file__
+ ui_file_path = os.path.join(os.path.dirname(__file__), 'Resources/UI/MedSAMLite.ui')
self.progressbar.setLabelText('Downloading updates...')
server_req = requests.get(server_url)
module_req = requests.get(module_url)
+ module_req = requests.get(ui_url)
- with open(os.path.join(self.server_dir, 'server.py'), 'w') as server_file:
+ with open(server_file_path, 'w') as server_file:
server_file.write(server_req.text)
- with open(__file__, 'w') as module_file:
+ with open(module_file_path, 'w') as module_file:
module_file.write(module_req.text)
+ with open(ui_file_path, 'w') as ui_file:
+ ui_file.write(ui_url.text)
self.progressbar.setLabelText('Upgraded successfully, please restart Slicer.')
else:
@@ -560,6 +595,29 @@ def upgrade(self, download, event):
event.set()
+ def check_server(self, serverUrl, max_retries, event = None):
+ retry_cnt = 0
+ while True:
+ try:
+ retry_cnt += 1
+ response = requests.post(f'{serverUrl}/getServerState')
+ server_ready = json.loads(response.json())
+ if server_ready['ready']:
+ if event:
+ event.set()
+ return True
+ elif retry_cnt == max_retries:
+ if event:
+ event.set()
+ return False
+ except Exception as e:
+ if retry_cnt == max_retries:
+ if event:
+ event.set()
+ return False
+ time.sleep(1)
+
+
def run_on_background(self, target, args, title):
self.progressbar = slicer.util.createProgressDialog(autoClose=False)
@@ -576,27 +634,27 @@ def run_on_background(self, target, args, title):
self.progressbar.close()
- def run_server(self):
- print('Running server...')
-
- # buggy_file_path = os.getcwd() + '/lib/Python/lib/python3.9/site-packages/typing_extensions.py'
-
- # with open(buggy_file_path, 'r') as file:
- # lines = file.readlines()
-
- # # Update the value in line 173
- # buggy_line_num = 173
- # new_value = '\t\t\tt, (typing._GenericAlias, _types.GenericAlias, _types.UnionType)'
+ def run_server(self, serverUrl, numpyServerAddress):
+ print('Terminating possible server duplicates...')
+ # Terminate image transferrer
+ try:
+ with NumpySocket() as s:
+ s.connect(numpyServerAddress)
+ s.sendall(np.array([]))
+ except:
+ pass
- # if 1 <= buggy_line_num <= len(lines):
- # lines[buggy_line_num - 1] = f"{new_value}\n"
+ # Terminate whole server
+ server_port = int(serverUrl.split(':')[-1])
+ try:
+ server_process = list(filter(lambda proc: proc.laddr.port == server_port and psutil.Process(proc.pid).name() == 'python-real', psutil.net_connections()))[0]
+ psutil.Process(server_process.pid).terminate()
+ except:
+ pass
- # # Write the updated content back to the file
- # with open(buggy_file_path, 'w') as file:
- # file.writelines(lines)
-
+ print('Running server...')
- self.server_process = subprocess.Popen(['PythonSlicer', os.path.join(self.server_dir, 'server.py')])#, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE, start_new_session=True)
+ self.server_process = subprocess.Popen(['PythonSlicer', os.path.join(self.server_dir, 'server.py'), self.widget.model_path_widget.currentPath])#, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE, start_new_session=True)
def cleanup():
timeout_sec = 5
@@ -609,9 +667,7 @@ def cleanup():
self.server_process.kill()
# atexit.register(cleanup)
- self.server_ready = True
-
- time.sleep(4) #Change
+ self.run_on_background(self.check_server, (serverUrl, 10), 'Backend is loading...')
def progressCheck(self, serverUrl='http://127.0.0.1:5555'):
@@ -628,13 +684,12 @@ def progressCheck(self, serverUrl='http://127.0.0.1:5555'):
def captureImage(self):
######## Set your image path here
self.volume_node = slicer.util.getNodesByClass('vtkMRMLScalarVolumeNode')[0]
- self.img_path = self.volume_node.GetStorageNode().GetFullNameFromFileName()
- self.img_sitk = sitk.ReadImage(self.img_path)
self.image_data = slicer.util.arrayFromVolume(self.volume_node) ################ Only one node?
def sendImage(self, serverUrl='http://127.0.0.1:5555', numpyServerAddress=("127.0.0.1", 5556)):
- if not self.server_ready:
- self.run_server()
+ if self.new_model_loaded or not self.check_server(serverUrl, max_retries=1, event=None):
+ self.run_server(serverUrl, numpyServerAddress)
+ self.new_model_loaded = False
print('sending setImage request...')
response = requests.post(f'{serverUrl}/setImage', json={"wmin": -160, "wmax": 240}) # wmin, wmax as input?
print('Response from setImage:', response.text)
@@ -678,14 +733,8 @@ def inferSegmentation(self, serverUrl='http://127.0.0.1:5555'):
return seg_result
def showSegmentation(self, segmentation_mask):
- segmentation_res_file = os.path.dirname(self.img_path) + '/lite_seg_' + os.path.basename(self.img_path)
- seg_sitk = sitk.GetImageFromArray(segmentation_mask)
- seg_sitk.CopyInformation(self.img_sitk)
- sitk.WriteImage(seg_sitk, segmentation_res_file) ########## Set your segmentation output here
- loaded_seg_file = slicer.util.loadSegmentation(segmentation_res_file)
-
- segment_volume = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLLabelMapVolumeNode")
- slicer.modules.segmentations.logic().ExportAllSegmentsToLabelmapNode(loaded_seg_file, segment_volume, slicer.vtkSegmentation.EXTENT_REFERENCE_GEOMETRY)
+ segment_volume = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLLabelMapVolumeNode", 'segment_'+str(int(time.time())))
+ slicer.util.updateVolumeFromArray(segment_volume, segmentation_mask)
current_seg_group = self.widget.editor.segmentationNode()
if current_seg_group is None:
@@ -696,14 +745,17 @@ def showSegmentation(self, segmentation_mask):
try:
check_if_node_is_removed = slicer.util.getNode(current_seg_group.GetID()) # if scene is closed and reopend, this line will raise an error
- slicer.modules.segmentations.logic().ImportLabelmapToSegmentationNode(segment_volume, current_seg_group)
except:
self.segment_res_group = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode")
self.segment_res_group.SetReferenceImageGeometryParameterFromVolumeNode(self.volume_node)
- slicer.modules.segmentations.logic().ImportLabelmapToSegmentationNode(segment_volume, self.segment_res_group)
+ current_seg_group = self.segment_res_group
+
+
+ current_seg_group.SetReferenceImageGeometryParameterFromVolumeNode(self.volume_node)
+ slicer.modules.segmentations.logic().ImportLabelmapToSegmentationNode(segment_volume, current_seg_group)
+ slicer.util.updateSegmentBinaryLabelmapFromArray(segmentation_mask, current_seg_group, segment_volume.GetName(), self.volume_node)
slicer.mrmlScene.RemoveNode(segment_volume)
- slicer.mrmlScene.RemoveNode(loaded_seg_file)
def applySegmentation(self, serverUrl='http://127.0.0.1:5555'):
segmentation_mask = self.inferSegmentation(serverUrl)
diff --git a/MedSAM/MedSAMLite/Resources/UI/MedSAMLite.ui b/MedSAM/MedSAMLite/Resources/UI/MedSAMLite.ui
index 631acac..64f366b 100644
--- a/MedSAM/MedSAMLite/Resources/UI/MedSAMLite.ui
+++ b/MedSAM/MedSAMLite/Resources/UI/MedSAMLite.ui
@@ -98,20 +98,39 @@
Start Segmentation
- -
+
-
Send Image
- -
+
-
Segmentation
+ -
+
+
+ ctkPathLineEdit::Executable|ctkPathLineEdit::Files|ctkPathLineEdit::NoDot|ctkPathLineEdit::NoDotDot|ctkPathLineEdit::Readable
+
+
+
+ *.pth
+
+
+
+
+ -
+
+
+ MedSAM Model:
+
+
+
@@ -124,6 +143,11 @@
1
+
+ ctkPathLineEdit
+ QWidget
+
+
qMRMLWidget
QWidget
diff --git a/server/server.py b/server/server.py
index d97093b..a5972ef 100644
--- a/server/server.py
+++ b/server/server.py
@@ -63,8 +63,7 @@ def medsam_inference(medsam_model, img_embed, box_1024, height, width):
# settings and app states
SAM_MODEL_TYPE = "vit_b"
-PARENT_DIR = os.path.dirname(os.path.abspath(__file__))
-MedSAM_CKPT_PATH = os.path.join(PARENT_DIR , "medsam_lite.pth")
+MedSAM_CKPT_PATH = sys.argv[1]
MEDSAM_IMG_INPUT_SIZE = 1024
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -293,6 +292,10 @@ def set_image(params: ImageParams, background_tasks: BackgroundTasks):
def get_progress():
return json.dumps({'layers': image.shape[0], 'generated_embeds': len(embeddings)})
+@app.post("/getServerState")
+def get_server_state():
+ return json.dumps({'ready': True})
+
class InferenceParams(BaseModel):
slice_idx: int
diff --git a/server_essentials.zip b/server_essentials.zip
index 6cb3b8e..82fab91 100644
Binary files a/server_essentials.zip and b/server_essentials.zip differ