Skip to content

Commit

Permalink
Merge pull request #42 from bowang-lab/Road_To_Extension_Index
Browse files Browse the repository at this point in the history
Faster and more diverse segmentation
  • Loading branch information
rasakereh authored Sep 25, 2024
2 parents 3f0abf6 + 0a1c6d6 commit bcd8ef7
Show file tree
Hide file tree
Showing 11 changed files with 5,333 additions and 304 deletions.
143 changes: 135 additions & 8 deletions MedSAM/MedSAMLite/MedSAMLite.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,16 @@

from slicer import vtkMRMLScalarVolumeNode

from PythonQt.QtCore import QTimer, QByteArray
from PythonQt.QtCore import QTimer, QByteArray, Qt
from PythonQt.QtGui import QIcon, QPixmap, QMessageBox

try:
import gdown
from medsam_interface import MedSAM_Interface # FIXME
except:
pass # no installation anymore, shorter plugin load

MEDSAMLITE_VERSION = 'v0.1'
MEDSAMLITE_VERSION = 'v0.11'

#
# MedSAMLite
Expand Down Expand Up @@ -197,7 +198,7 @@ def setup(self) -> None:
############################################################################
# Model Selection
self.model_path_widget = self.ui.ctkPathModel
self.model_path_widget.currentPath = os.path.join(self.logic.server_dir, 'medsam_lite.pth')
self.model_path_widget.currentPath = os.path.join(self.logic.server_dir, 'medsam_interface/models/classic/medsam_lite.pth')
self.logic.new_model_loaded = True
############################################################################

Expand All @@ -209,7 +210,7 @@ def setup(self) -> None:
self.editor.setMaximumNumberOfUndoStates(10)
self.selectParameterNode()
self.editor.setMRMLScene(slicer.mrmlScene)
self.ui.clbtnOperation.layout().addWidget(self.editor, 3, 0, 1, 2)
self.ui.clbtnOperation.layout().addWidget(self.editor, 5, 0, 1, 2)
############################################################################

###########################################################################
Expand Down Expand Up @@ -251,6 +252,66 @@ def setup(self) -> None:
self.ui.widgetROI.findChild("QCheckBox", "insideOutCheckBox").hide()
self.ui.widgetROI.findChild("QLabel", "label_10").hide()
self.ui.widgetROI.findChild("QComboBox", "roiTypeComboBox").hide()

# Segmentation Engine
self.engine_list = [
{
'name': 'Classic MedSAM',
'description': 'Classic MedSAM engine uses PyTorch. It supports GPU and approximate segmentation calculation for faster results.',
'default checkpoint': os.path.join(self.logic.server_dir, 'medsam_interface/models/classic/medsam_lite.pth'),
'controls to hide': [self.ui.lblSubModel, self.ui.cmbSubModel],
'controls to show': [self.ui.cmbSpeed],
'url': 'https://drive.google.com/drive/folders/1cSLWY_kwiV3JXRNJktZSbhUwMxZ5eV-q?usp=sharing',
'submodels': {}
},
{
'name': 'OpenVino MedSAM',
'description': 'OpenVino MedSAM is faster than Classic MedSAM on CPU as it uses OpenVINO. Approximate segmentation calculation for faster results are supported. No GPU support.',
'default checkpoint': os.path.join(self.logic.server_dir, 'medsam_interface/models/openvino/medsam_lite_image_encoder.xml'),
'controls to hide': [self.ui.lblSubModel, self.ui.cmbSubModel],
'controls to show': [self.ui.cmbSpeed],
'url': 'https://drive.google.com/drive/folders/1FTwy6uOUFIrWnrkBbTNufv8N9r34hmeG?usp=sharing',
'submodels': {}
},
{
'name': 'DAFT MedSAM',
'description': 'DAFT MedSAM is one of the fastest engines as it uses a relatively smaller data-specific model and OpenVINO backend. No approximate segmentation nor GPU support and need for user\'s mindful model selection are the cons.',
'default checkpoint': '',
'controls to hide': [self.ui.cmbSpeed],
'controls to show': [self.ui.lblSubModel, self.ui.cmbSubModel],
'url': '',
'submodels': {
'3D (CT, MR, PTE)': {'checkpoint': os.path.join(self.logic.server_dir, 'medsam_interface/models/daft/3D/image_encoder.xml'), 'url': 'https://drive.google.com/drive/folders/1jR7Qz-RSm-uDaZzpOxFI4wBFeCxeSFEb?usp=drive_link'},
'Dermoscopy': {'checkpoint': os.path.join(self.logic.server_dir, 'medsam_interface/models/daft/Dermoscopy/image_encoder.xml'), 'url': 'https://drive.google.com/drive/folders/1Zwwp0kScYJsLB1exs63B_HODT-9csj_g?usp=drive_link'},
'Endoscopy': {'checkpoint': os.path.join(self.logic.server_dir, 'medsam_interface/models/daft/Endoscopy/image_encoder.xml'), 'url': 'https://drive.google.com/drive/folders/1-QrmdwEUYEZsrXlEilhrovP-299li1J-?usp=drive_link'},
'Fundus': {'checkpoint': os.path.join(self.logic.server_dir, 'medsam_interface/models/daft/Fundus/image_encoder.xml'), 'url': 'https://drive.google.com/drive/folders/1I2QESz1VXcKDrKDg-Er44PqwS00vtSMc?usp=drive_link'},
'general': {'checkpoint': os.path.join(self.logic.server_dir, 'medsam_interface/models/daft/general/image_encoder.xml'), 'url': 'https://drive.google.com/drive/folders/1ojqPtCYwh-bzPdgA7GS0Zt78AO3cjde5?usp=drive_link'},
'Mammography': {'checkpoint': os.path.join(self.logic.server_dir, 'medsam_interface/models/daft/Mammography/image_encoder.xml'), 'url': 'https://drive.google.com/drive/folders/1kS0s7fcIlXE-hS0-sWDNWXWHQ9hxqHhh?usp=drive_link'},
'Microscopy': {'checkpoint': os.path.join(self.logic.server_dir, 'medsam_interface/models/daft/Microscopy/image_encoder.xml'), 'url': 'https://drive.google.com/drive/folders/1p788QfFuLZW2XBKyjpbS9leoWOJhn5wg?usp=drive_link'},
'OCT': {'checkpoint': os.path.join(self.logic.server_dir, 'medsam_interface/models/daft/OCT/image_encoder.xml'), 'url': 'https://drive.google.com/drive/folders/1RcX686vYU-jHWwi8NZ9JWKSc61vmF_XB?usp=drive_link'},
'US': {'checkpoint': os.path.join(self.logic.server_dir, 'medsam_interface/models/daft/US/image_encoder.xml'), 'url': 'https://drive.google.com/drive/folders/1dWifPYpA168KbUoKF5XWBnCBaU4fzeMm?usp=drive_link'},
'XRay': {'checkpoint': os.path.join(self.logic.server_dir, 'medsam_interface/models/daft/XRay/image_encoder.xml'), 'url': 'https://drive.google.com/drive/folders/120gqhi-psC0c1W-D18iXiya9zuH2a9nX?usp=drive_link'},
}
},
# {
# 'name': 'Medficient SAM',
# 'description': 'Medficient SAM [.... placeholder ....]',
# 'default checkpoint': os.path.join(self.logic.server_dir, 'medsam_interface/models/medficient/model.pth'),
# 'controls to hide': [self.ui.cmbSpeed, self.ui.lblSubModel, self.ui.cmbSubModel],
# 'controls to show': [],
# 'url': '',
# 'submodels': {}
# },
]

self.ui.cmbEngine.addItems(list(map(lambda x: x['name'], self.engine_list)))
for i, engine in enumerate(self.engine_list):
self.ui.cmbEngine.setItemData(i, engine['description'], Qt.ToolTipRole)
self.ui.cmbEngine.currentTextChanged.connect(self.newEngineSelected)
self.ui.cmbSubModel.currentTextChanged.connect(self.newSubmodelSelected)

# Segmentation Speed
self.ui.cmbSpeed.addItems(['Normal Speed - Highest Quality', 'Faster Segmentation - High Quality', 'Fastest Segmentation - Moderate Quality'])

self.ui.pbAttach.connect('clicked(bool)', lambda: self._createAndAttachROI())
self.ui.pbTwoDim.connect('clicked(bool)', lambda: self.makeROI2D())
Expand All @@ -261,13 +322,53 @@ def setup(self) -> None:

# Make sure parameter node is initialized (needed for module reload)
self.initializeParameterNode()
self.newEngineSelected('Classic MedSAM')

def setManualPreprocessVis(self, visible):
self.ui.lblLevel.setVisible(visible)
self.ui.lblWidth.setVisible(visible)
self.ui.sldWinLevel.setVisible(visible)
self.ui.sldWinWidth.setVisible(visible)

def newEngineSelected(self, new_engine):
current_engine = list(filter(lambda x: x['name'] == new_engine, self.engine_list))[0]
# inform logic object
self.logic.new_model_loaded = True
# load list of submodels
self.dont_invoke_submodel_change = True # prevent onchange event to happen
self.ui.cmbSubModel.clear()
self.ui.cmbSubModel.addItems(list(current_engine['submodels'].keys()))
self.dont_invoke_submodel_change = False
# show/hide engine specific options
for ctrl in current_engine['controls to show']:
ctrl.setVisible(True)
for ctrl in current_engine['controls to hide']:
ctrl.setVisible(False)

# change engine-related paths
self.model_path_widget.currentPath = current_engine['default checkpoint']
self.updateAllParameters()

# if there is a submodel, choose the first one
if len(current_engine['submodels']) > 0:
self.newSubmodelSelected(list(current_engine['submodels'].keys())[0])

# download checkpoints if necessary
if len(current_engine['submodels']) == 0:
self.logic.download_if_necessary(current_engine['url'], current_engine['default checkpoint'])

def newSubmodelSelected(self, new_submodel):
if self.dont_invoke_submodel_change: return
current_submodel = list(filter(lambda x: x['name'] == self.ui.cmbEngine.currentText, self.engine_list))[0]['submodels'][new_submodel]
# inform logic object
self.logic.new_model_loaded = True

# change submodel-related paths
self.model_path_widget.currentPath = current_submodel['checkpoint']
self.updateAllParameters()

# download checkpoints if necessary
self.logic.download_if_necessary(current_submodel['url'], current_submodel['checkpoint'])

def selectParameterNode(self):
# Select parameter set node if one is found in the scene, and create one otherwise
Expand Down Expand Up @@ -410,7 +511,10 @@ def setParameterNode(self, inputParameterNode: Optional[MedSAMLiteParameterNode]
if self._parameterNode:
# Note: in the .ui file, a Qt dynamic property called "SlicerParameterName" is set on each
# ui element that needs connection.
self._parameterNodeGuiTag = self._parameterNode.connectGui(self.ui)
try:
self._parameterNodeGuiTag = self._parameterNode.connectGui(self.ui)
except:
pass #this part might be invoked before UI is loaded. does not cause any issues but might be confusing for users
self.renderAllParameters()


Expand Down Expand Up @@ -474,6 +578,10 @@ def __init__(self) -> None:
Called when the logic class is instantiated. Can be used for initializing member variables.
"""
ScriptedLoadableModuleLogic.__init__(self)
try: # In case the dependencies are not installed, an error will raise
self.backend = MedSAM_Interface()
except:
pass

def getParameterNode(self):
return MedSAMLiteParameterNode(super().getParameterNode())
Expand Down Expand Up @@ -526,6 +634,10 @@ def download_wrapper(self, url, filename, download_needed, event):

def install_dependencies(self):
dependencies = {
'ONNX': 'onnx>=1.16.2',
'ONNX Runtime': 'onnxruntime>=1.19.2',
'Google Drive Downloader': 'gdown>=5.2.0',
'OpenVINO': 'openvino-dev',
'PyTorch': 'torch==2.0.1 torchvision==0.15.2',
'MedSam Lite Server': '-e "%s"'%(self.server_dir)
}
Expand Down Expand Up @@ -600,9 +712,21 @@ def run_on_background(self, target, args, title, progress_check=None):

self.progressbar.close()

def download_model(self, url, model_path, event):
gdown.download_folder(url=url, output=model_path)
event.set()

def download_if_necessary(self, model_url, model_path):
model_path = os.path.dirname(model_path)
if not os.path.isdir(model_path):
continueDownload = QMessageBox.question(None,'', "You need to download extra model files for this engine. Do you want to continue? (downloading %s to %s)"%(model_url, model_path), QMessageBox.Yes | QMessageBox.No)
if continueDownload == QMessageBox.No: return
self.run_on_background(self.download_model, (model_url, model_path), "Downloading model files...")

def run_server(self):
#FIXME show that 'Backend is loading...'
self.backend = MedSAM_Interface()
self.widget.updateAllParameters()
self.backend.set_engine(self.widget.ui.cmbEngine.currentText)
self.widget.renderAllParameters()
self.backend.MedSAM_CKPT_PATH = self.widget.model_path_widget.currentPath
self.backend.load_model()
Expand All @@ -613,7 +737,7 @@ def progressCheck(self, partial=False):
progress_data = self.backend.get_progress()
self.progressbar.value = progress_data['generated_embeds']

if progress_data['layers'] == progress_data['generated_embeds']:
if progress_data['layers'] <= progress_data['generated_embeds']:
self.progressbar.close()
self.timer.stop()
self.widget.ui.pbSegment.setEnabled(True)
Expand All @@ -632,6 +756,8 @@ def captureImage(self):
######## Set your image path here
self.volume_node = slicer.util.getNodesByClass('vtkMRMLScalarVolumeNode')[0]
self.image_data = slicer.util.arrayFromVolume(self.volume_node) ################ Only one node?
if len(self.image_data.shape) == 4 and self.image_data.shape[-1] == 4: # colored image, it can have 4 channels (r,g,b,a) so we remove the last one
self.image_data = self.image_data[:,:,:,:3]

def sendImage(self, partial=False):
self.widget.ui.pbSegment.setEnabled(False)
Expand Down Expand Up @@ -659,6 +785,7 @@ def sendImage(self, partial=False):
self.timer.timeout.connect(lambda: self.progressCheck(partial))
self.timer.start(1000)

self.backend.speed_level = 1 if 'Normal' in self.widget.ui.cmbSpeed.currentText else 2 if 'Faster' in self.widget.ui.cmbSpeed.currentText else 3
self.backend.set_image(self.image_data, -160, 240, zmin, zmax, recurrent_func=slicer.app.processEvents)

self.widget.updateAllParameters()
Expand All @@ -680,7 +807,7 @@ def inferSegmentation(self):
slice_idx, bbox, zrange = self.get_bounding_box()
seg_data = self.backend.infer(slice_idx, bbox, zrange)
frames = list(seg_data.keys())
seg_result = np.zeros_like(self.image_data)
seg_result = np.zeros(self.image_data.shape[:3])
for frame in frames:
seg_result[frame, :, :] = seg_data[frame]

Expand Down
50 changes: 38 additions & 12 deletions MedSAM/MedSAMLite/Resources/UI/MedSAMLite.ui
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,6 @@
<string>Select the Region of Interest</string>
</property>
<layout class="QGridLayout" name="gridLayout_3">
<item row="2" column="0" colspan="2">
<widget class="qMRMLMarkupsROIWidget" name="widgetROI"/>
</item>
<item row="0" column="0">
<widget class="QPushButton" name="pbAttach">
<property name="text">
Expand Down Expand Up @@ -124,6 +121,9 @@
</property>
</widget>
</item>
<item row="2" column="0" colspan="2">
<widget class="qMRMLMarkupsROIWidget" name="widgetROI"/>
</item>
</layout>
</widget>
</item>
Expand All @@ -134,38 +134,64 @@
</property>
<layout class="QGridLayout" name="gridLayout_4">
<item row="2" column="0">
<widget class="QPushButton" name="pbSendImage">
<widget class="QLabel" name="lblModelSelection">
<property name="text">
<string>Send Image</string>
<string>MedSAM Model:</string>
</property>
</widget>
</item>
<item row="2" column="1">
<widget class="QPushButton" name="pbSegment">
<item row="1" column="0">
<widget class="QLabel" name="lblSubModel">
<property name="text">
<string>Segmentation</string>
<string>Submodel:</string>
</property>
</widget>
</item>
<item row="1" column="1">
<item row="2" column="1">
<widget class="ctkPathLineEdit" name="ctkPathModel">
<property name="filters">
<set>ctkPathLineEdit::Executable|ctkPathLineEdit::Files|ctkPathLineEdit::NoDot|ctkPathLineEdit::NoDotDot|ctkPathLineEdit::Readable</set>
</property>
<property name="nameFilters">
<stringlist>
<string>*.pth</string>
<string>*.xml</string>
<string>*.onnx</string>
<string>*.ckpt</string>
</stringlist>
</property>
</widget>
</item>
<item row="1" column="0">
<widget class="QLabel" name="lblModelSelection">
<item row="0" column="1">
<widget class="QComboBox" name="cmbEngine"/>
</item>
<item row="3" column="0" colspan="2">
<widget class="QComboBox" name="cmbSpeed"/>
</item>
<item row="4" column="0">
<widget class="QPushButton" name="pbSendImage">
<property name="text">
<string>MedSAM Model:</string>
<string>Send Image</string>
</property>
</widget>
</item>
<item row="0" column="0">
<widget class="QLabel" name="lblEngineSelection">
<property name="text">
<string>Select Engine:</string>
</property>
</widget>
</item>
<item row="4" column="1">
<widget class="QPushButton" name="pbSegment">
<property name="text">
<string>Segmentation</string>
</property>
</widget>
</item>
<item row="1" column="1">
<widget class="QComboBox" name="cmbSubModel"/>
</item>
</layout>
</widget>
</item>
Expand Down
Loading

0 comments on commit bcd8ef7

Please sign in to comment.