Skip to content

Commit

Permalink
Fine-tuned models can be loaded
Browse files Browse the repository at this point in the history
  • Loading branch information
rasakereh committed Feb 15, 2024
1 parent 0b6925a commit f634e41
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 43 deletions.
130 changes: 91 additions & 39 deletions MedSAM/MedSAMLite/MedSAMLite.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

try:
from numpysocket import NumpySocket
import psutil
except:
pass # no installation anymore, shorter plugin load

Expand Down Expand Up @@ -213,6 +214,29 @@ def setup(self) -> None:
self.layout.addWidget(uiWidget)
self.ui = slicer.util.childWidgetVariables(uiWidget)

############################################################################
# Model Selection
if hasattr(self.ui, 'ctkPathModel'):
self.model_path_widget = self.ui.ctkPathModel
# self.ui.clbtnOperation.layout().addWidget(self.ui.lblModelSelection, 0, 0)
# self.ui.clbtnOperation.layout().addWidget(self.ui.ctkPathModel, 0, 1)
else:
import ctk
from PythonQt.QtGui import QLabel

path_instruction = QLabel('MedSAM Model:')

self.model_path_widget = ctk.ctkPathLineEdit()
self.model_path_widget.filters = ctk.ctkPathLineEdit.Files
self.model_path_widget.nameFilters = ['*.pth']

self.ui.clbtnOperation.layout().addWidget(path_instruction, 0, 0)
self.ui.clbtnOperation.layout().addWidget(self.model_path_widget, 0, 1)

self.model_path_widget.currentPath = os.path.join(self.logic.server_dir, 'medsam_lite.pth')
self.logic.new_model_loaded = True
############################################################################


############################################################################
# Segmentation Module
Expand All @@ -222,7 +246,7 @@ def setup(self) -> None:
self.selectParameterNode()
self.editor.setMRMLScene(slicer.mrmlScene)
# print(self.ui.clbtnOperation.layout().__dict__)
self.ui.clbtnOperation.layout().addWidget(self.editor)
self.ui.clbtnOperation.layout().addWidget(self.editor, 3, 0, 1, 2)
# self.layout.addWidget(self.editor)
# self.editor.currentSegmentIDChanged.connect(print)
############################################################################
Expand Down Expand Up @@ -251,6 +275,8 @@ def setup(self) -> None:
self.ui.pbAttach.connect('clicked(bool)', lambda: self._createAndAttachROI())
self.ui.pbTwoDim.connect('clicked(bool)', lambda: self.makeROI2D())

self.model_path_widget.connect('currentPathChanged(const QString&)', lambda: setattr(self.logic, 'new_model_loaded', True))

# Make sure parameter node is initialized (needed for module reload)
self.initializeParameterNode()

Expand Down Expand Up @@ -419,6 +445,7 @@ class MedSAMLiteLogic(ScriptedLoadableModuleLogic):
progressbar = None
server_dir = None
widget = None
new_model_loaded = True

def __init__(self) -> None:
"""
Expand Down Expand Up @@ -540,15 +567,23 @@ def upgrade(self, download, event):
github_base = 'https://raw.githubusercontent.com/bowang-lab/MedSAMSlicer/v%.2f/'%latest_version
server_url = github_base + 'server/server.py'
module_url = github_base + 'MedSAM/MedSAMLite/MedSAMLite.py'
ui_url = github_base + 'MedSAM/MedSAMLite/Resources/UI/MedSAMLite.ui'

server_file_path = os.path.join(self.server_dir, 'server.py')
module_file_path = __file__
ui_file_path = os.path.join(os.path.dirname(__file__), 'Resources/UI/MedSAMLite.ui')

self.progressbar.setLabelText('Downloading updates...')
server_req = requests.get(server_url)
module_req = requests.get(module_url)
module_req = requests.get(ui_url)

with open(os.path.join(self.server_dir, 'server.py'), 'w') as server_file:
with open(server_file_path, 'w') as server_file:
server_file.write(server_req.text)
with open(__file__, 'w') as module_file:
with open(module_file_path, 'w') as module_file:
module_file.write(module_req.text)
with open(ui_file_path, 'w') as ui_file:
ui_file.write(ui_url.text)
self.progressbar.setLabelText('Upgraded successfully, please restart Slicer.')

else:
Expand All @@ -560,6 +595,29 @@ def upgrade(self, download, event):

event.set()

def check_server(self, serverUrl, max_retries, event = None):
retry_cnt = 0
while True:
try:
retry_cnt += 1
response = requests.post(f'{serverUrl}/getServerState')
server_ready = json.loads(response.json())
if server_ready['ready']:
if event:
event.set()
return True
elif retry_cnt == max_retries:
if event:
event.set()
return False
except Exception as e:
if retry_cnt == max_retries:
if event:
event.set()
return False
time.sleep(1)



def run_on_background(self, target, args, title):
self.progressbar = slicer.util.createProgressDialog(autoClose=False)
Expand All @@ -576,27 +634,27 @@ def run_on_background(self, target, args, title):

self.progressbar.close()

def run_server(self):
print('Running server...')

# buggy_file_path = os.getcwd() + '/lib/Python/lib/python3.9/site-packages/typing_extensions.py'

# with open(buggy_file_path, 'r') as file:
# lines = file.readlines()

# # Update the value in line 173
# buggy_line_num = 173
# new_value = '\t\t\tt, (typing._GenericAlias, _types.GenericAlias, _types.UnionType)'
def run_server(self, serverUrl, numpyServerAddress):
print('Terminating possible server duplicates...')
# Terminate image transferrer
try:
with NumpySocket() as s:
s.connect(numpyServerAddress)
s.sendall(np.array([]))
except:
pass

# if 1 <= buggy_line_num <= len(lines):
# lines[buggy_line_num - 1] = f"{new_value}\n"
# Terminate whole server
server_port = int(serverUrl.split(':')[-1])
try:
server_process = list(filter(lambda proc: proc.laddr.port == server_port and psutil.Process(proc.pid).name() == 'python-real', psutil.net_connections()))[0]
psutil.Process(server_process.pid).terminate()
except:
pass

# # Write the updated content back to the file
# with open(buggy_file_path, 'w') as file:
# file.writelines(lines)

print('Running server...')

self.server_process = subprocess.Popen(['PythonSlicer', os.path.join(self.server_dir, 'server.py')])#, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE, start_new_session=True)
self.server_process = subprocess.Popen(['PythonSlicer', os.path.join(self.server_dir, 'server.py'), self.widget.model_path_widget.currentPath])#, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE, start_new_session=True)
def cleanup():
timeout_sec = 5

Expand All @@ -609,9 +667,7 @@ def cleanup():
self.server_process.kill()
# atexit.register(cleanup)

self.server_ready = True

time.sleep(4) #Change
self.run_on_background(self.check_server, (serverUrl, 10), 'Backend is loading...')


def progressCheck(self, serverUrl='http://127.0.0.1:5555'):
Expand All @@ -628,13 +684,12 @@ def progressCheck(self, serverUrl='http://127.0.0.1:5555'):
def captureImage(self):
######## Set your image path here
self.volume_node = slicer.util.getNodesByClass('vtkMRMLScalarVolumeNode')[0]
self.img_path = self.volume_node.GetStorageNode().GetFullNameFromFileName()
self.img_sitk = sitk.ReadImage(self.img_path)
self.image_data = slicer.util.arrayFromVolume(self.volume_node) ################ Only one node?

def sendImage(self, serverUrl='http://127.0.0.1:5555', numpyServerAddress=("127.0.0.1", 5556)):
if not self.server_ready:
self.run_server()
if self.new_model_loaded or not self.check_server(serverUrl, max_retries=1, event=None):
self.run_server(serverUrl, numpyServerAddress)
self.new_model_loaded = False
print('sending setImage request...')
response = requests.post(f'{serverUrl}/setImage', json={"wmin": -160, "wmax": 240}) # wmin, wmax as input?
print('Response from setImage:', response.text)
Expand Down Expand Up @@ -678,14 +733,8 @@ def inferSegmentation(self, serverUrl='http://127.0.0.1:5555'):
return seg_result

def showSegmentation(self, segmentation_mask):
segmentation_res_file = os.path.dirname(self.img_path) + '/lite_seg_' + os.path.basename(self.img_path)
seg_sitk = sitk.GetImageFromArray(segmentation_mask)
seg_sitk.CopyInformation(self.img_sitk)
sitk.WriteImage(seg_sitk, segmentation_res_file) ########## Set your segmentation output here
loaded_seg_file = slicer.util.loadSegmentation(segmentation_res_file)

segment_volume = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLLabelMapVolumeNode")
slicer.modules.segmentations.logic().ExportAllSegmentsToLabelmapNode(loaded_seg_file, segment_volume, slicer.vtkSegmentation.EXTENT_REFERENCE_GEOMETRY)
segment_volume = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLLabelMapVolumeNode", 'segment_'+str(int(time.time())))
slicer.util.updateVolumeFromArray(segment_volume, segmentation_mask)

current_seg_group = self.widget.editor.segmentationNode()
if current_seg_group is None:
Expand All @@ -696,14 +745,17 @@ def showSegmentation(self, segmentation_mask):

try:
check_if_node_is_removed = slicer.util.getNode(current_seg_group.GetID()) # if scene is closed and reopend, this line will raise an error
slicer.modules.segmentations.logic().ImportLabelmapToSegmentationNode(segment_volume, current_seg_group)
except:
self.segment_res_group = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode")
self.segment_res_group.SetReferenceImageGeometryParameterFromVolumeNode(self.volume_node)
slicer.modules.segmentations.logic().ImportLabelmapToSegmentationNode(segment_volume, self.segment_res_group)
current_seg_group = self.segment_res_group


current_seg_group.SetReferenceImageGeometryParameterFromVolumeNode(self.volume_node)
slicer.modules.segmentations.logic().ImportLabelmapToSegmentationNode(segment_volume, current_seg_group)
slicer.util.updateSegmentBinaryLabelmapFromArray(segmentation_mask, current_seg_group, segment_volume.GetName(), self.volume_node)

slicer.mrmlScene.RemoveNode(segment_volume)
slicer.mrmlScene.RemoveNode(loaded_seg_file)

def applySegmentation(self, serverUrl='http://127.0.0.1:5555'):
segmentation_mask = self.inferSegmentation(serverUrl)
Expand Down
28 changes: 26 additions & 2 deletions MedSAM/MedSAMLite/Resources/UI/MedSAMLite.ui
Original file line number Diff line number Diff line change
Expand Up @@ -98,20 +98,39 @@
<string>Start Segmentation</string>
</property>
<layout class="QGridLayout" name="gridLayout_4">
<item row="0" column="0">
<item row="2" column="0">
<widget class="QPushButton" name="pbSendImage">
<property name="text">
<string>Send Image</string>
</property>
</widget>
</item>
<item row="0" column="1">
<item row="2" column="1">
<widget class="QPushButton" name="pbSegment">
<property name="text">
<string>Segmentation</string>
</property>
</widget>
</item>
<item row="1" column="1">
<widget class="ctkPathLineEdit" name="ctkPathModel">
<property name="filters">
<set>ctkPathLineEdit::Executable|ctkPathLineEdit::Files|ctkPathLineEdit::NoDot|ctkPathLineEdit::NoDotDot|ctkPathLineEdit::Readable</set>
</property>
<property name="nameFilters">
<stringlist>
<string>*.pth</string>
</stringlist>
</property>
</widget>
</item>
<item row="1" column="0">
<widget class="QLabel" name="lblModelSelection">
<property name="text">
<string>MedSAM Model:</string>
</property>
</widget>
</item>
</layout>
</widget>
</item>
Expand All @@ -124,6 +143,11 @@
<header>ctkCollapsibleButton.h</header>
<container>1</container>
</customwidget>
<customwidget>
<class>ctkPathLineEdit</class>
<extends>QWidget</extends>
<header>ctkPathLineEdit.h</header>
</customwidget>
<customwidget>
<class>qMRMLWidget</class>
<extends>QWidget</extends>
Expand Down
7 changes: 5 additions & 2 deletions server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ def medsam_inference(medsam_model, img_embed, box_1024, height, width):

# settings and app states
SAM_MODEL_TYPE = "vit_b"
PARENT_DIR = os.path.dirname(os.path.abspath(__file__))
MedSAM_CKPT_PATH = os.path.join(PARENT_DIR , "medsam_lite.pth")
MedSAM_CKPT_PATH = sys.argv[1]
MEDSAM_IMG_INPUT_SIZE = 1024
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Expand Down Expand Up @@ -293,6 +292,10 @@ def set_image(params: ImageParams, background_tasks: BackgroundTasks):
def get_progress():
return json.dumps({'layers': image.shape[0], 'generated_embeds': len(embeddings)})

@app.post("/getServerState")
def get_server_state():
return json.dumps({'ready': True})


class InferenceParams(BaseModel):
slice_idx: int
Expand Down
Binary file modified server_essentials.zip
Binary file not shown.

0 comments on commit f634e41

Please sign in to comment.