Skip to content

Commit

Permalink
Update added, fixed non-square image issue
Browse files Browse the repository at this point in the history
  • Loading branch information
rasakereh committed Feb 2, 2024
1 parent e28f355 commit e7048eb
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 31 deletions.
110 changes: 88 additions & 22 deletions MedSAM/MedSAMLite/MedSAMLite.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
except:
pass # no installation anymore, shorter plugin load

MEDSAMLITE_VERSION = 'v0.03'

#
# MedSAMLite
#
Expand Down Expand Up @@ -181,18 +183,22 @@ def setup(self) -> None:
DEPENDENCIES_AVAILABLE = False

if not DEPENDENCIES_AVAILABLE:
from PythonQt.QtGui import QLabel, QPushButton, QSpacerItem, QSizePolicy
from PythonQt.QtGui import QLabel, QPushButton, QSpacerItem, QSizePolicy, QCheckBox
import ctk
path_instruction = QLabel('Choose a folder to install module dependencies in')
restart_instruction = QLabel('Restart 3D Slicer after all dependencies are installed!')

ctk_install_path = ctk.ctkPathLineEdit()
ctk_install_path.filters = ctk.ctkPathLineEdit.Dirs

local_install = QCheckBox("Install from local server_essentials.zip")
local_install.toggled.connect(lambda:self.toggleLocalInstall(local_install, ctk_install_path))

install_btn = QPushButton('Install dependencies')
install_btn.clicked.connect(lambda: self.logic.install_dependencies(ctk_install_path))

self.layout.addWidget(path_instruction)
self.layout.addWidget(local_install)
self.layout.addWidget(ctk_install_path)
self.layout.addWidget(install_btn)
self.layout.addWidget(restart_instruction)
Expand Down Expand Up @@ -233,6 +239,7 @@ def setup(self) -> None:
self.addObserver(slicer.mrmlScene, slicer.mrmlScene.EndCloseEvent, self.onSceneEndClose)

# Buttons
self.ui.pbUpgrade.connect('clicked(bool)', lambda: self.logic.run_on_background(self.logic.upgrade, (True,), 'Checking for updates...'))
self.ui.pbSendImage.connect('clicked(bool)', lambda: self.logic.sendImage())
self.ui.pbSegment.connect('clicked(bool)', lambda: self.logic.applySegmentation())

Expand Down Expand Up @@ -289,6 +296,15 @@ def makeROI2D(self):
roiNode.SetSize(roi_size[0], roi_size[1], 1)
roi_center = np.array(roiNode.GetCenter())
roiNode.SetCenter([roi_center[0], roi_center[1], slicer.app.layoutManager().sliceWidget("Red").sliceLogic().GetSliceOffset()])

def toggleLocalInstall(self, checkbox, file_selector):
import ctk
file_selector.currentPath = ''
if checkbox.isChecked():
file_selector.filters = ctk.ctkPathLineEdit.Files
file_selector.nameFilters = ['server_essentials.zip']
else:
file_selector.filters = ctk.ctkPathLineEdit.Dirs


def cleanup(self) -> None:
Expand Down Expand Up @@ -454,18 +470,19 @@ def pip_install_wrapper(self, command, event):
slicer.util.pip_install(command)
event.set()

def download_wrapper(self, url, filename, event):
with urlopen(url) as r:
# self.setTotalProgress.emit(int(r.info()["Content-Length"]))
with open(filename, "wb") as f:
while True:
chunk = r.read(1024)
if chunk is None:
continue
elif chunk == b"":
break
f.write(chunk)

def download_wrapper(self, url, filename, download_needed, event):
if download_needed:
with urlopen(url) as r:
# self.setTotalProgress.emit(int(r.info()["Content-Length"]))
with open(filename, "wb") as f:
while True:
chunk = r.read(1024)
if chunk is None:
continue
elif chunk == b"":
break
f.write(chunk)

with zipfile.ZipFile(filename, 'r') as zip_ref:
zip_ref.extractall(os.path.dirname(filename))

Expand All @@ -476,15 +493,26 @@ def install_dependencies(self, ctk_path):
print('Installation path is empty')
return

print('Installation will happen in %s'%ctk_path.currentPath)
self.widget.write_setting(ctk_path.currentPath)
if os.path.isfile(ctk_path.currentPath) and os.path.basename(ctk_path.currentPath) == 'server_essentials.zip':
install_path = os.path.abspath(os.path.dirname(ctk_path.currentPath))
download_needed = False
elif os.path.isdir(ctk_path.currentPath):
install_path = ctk_path.currentPath
download_needed = True
else:
print('Invalid installation path')
return

print('Installation will happen in %s'%install_path)

self.widget.write_setting(install_path)

file_url = 'https://github.com/rasakereh/medsam-3dslicer/raw/master/server_essentials.zip'
filename = os.path.join(ctk_path.currentPath, 'server_essentials.zip')
filename = os.path.join(install_path, 'server_essentials.zip')

self.run_on_background(self.download_wrapper, (file_url, filename), 'Downloading additional files...')
self.run_on_background(self.download_wrapper, (file_url, filename, download_needed), 'Downloading additional files...')

self.server_dir = os.path.join(ctk_path.currentPath + '/', 'server_essentials')
self.server_dir = os.path.join(install_path + '/', 'server_essentials')

dependencies = {
'PyTorch': 'torch==2.0.1 torchvision==0.15.2',
Expand All @@ -498,6 +526,41 @@ def install_dependencies(self, ctk_path):
self.run_on_background(self.pip_install_wrapper, (dependencies[dependency],), 'Installing dependencies: %s'%dependency)


def upgrade(self, download, event):
try:
self.progressbar.setLabelText('Checking for updates...')
latest_version_req = requests.get('https://github.com/bowang-lab/MedSAMSlicer/releases/latest')
latest_version = latest_version_req.url.split('/')[-1]
latest_version = float(latest_version[1:])
curr_version = float(MEDSAMLITE_VERSION[1:])
print('Latest version identified:', latest_version)
print('Current version is:', curr_version)
if latest_version > curr_version:
print('Upgrade available')
github_base = 'https://raw.githubusercontent.com/bowang-lab/MedSAMSlicer/v%.2f/'%latest_version
server_url = github_base + 'server/server.py'
module_url = github_base + 'MedSAM/MedSAMLite/MedSAMLite.py'

self.progressbar.setLabelText('Downloading updates...')
server_req = requests.get(server_url)
module_req = requests.get(module_url)

with open(os.path.join(self.server_dir, 'server.py'), 'w') as server_file:
server_file.write(server_req.text)
with open(__file__, 'w') as module_file:
module_file.write(module_req.text)
self.progressbar.setLabelText('Upgraded successfully, please restart Slicer.')

else:
self.progressbar.setLabelText('Already using the latest version')
except:
self.progressbar.setLabelText('Error happened while upgrading')

time.sleep(3)

event.set()


def run_on_background(self, target, args, title):
self.progressbar = slicer.util.createProgressDialog(autoClose=False)
self.progressbar.minimum = 0
Expand Down Expand Up @@ -597,6 +660,11 @@ def sendImage(self, serverUrl='http://127.0.0.1:5555', numpyServerAddress=("127.

def inferSegmentation(self, serverUrl='http://127.0.0.1:5555'):
print('sending infer request...')
################ DEBUG MODE ################
if self.volume_node is None:
self.captureImage()
################ DEBUG MODE ################

slice_idx, bbox, zrange = self.get_bounding_box()

response = requests.post(f'{serverUrl}/infer', json={"slice_idx": slice_idx, "bbox": bbox, "zrange": zrange})
Expand Down Expand Up @@ -671,8 +739,7 @@ def get_bounding_box(self):
return slice_idx, bbox, zrange

def preprocess_CT(self, win_level=40.0, win_width=400.0):
if self.image_data is None:
self.captureImage()
self.captureImage()
# self.volume_node.GetDisplayNode().SetThreshold(0, 255)
# self.volume_node.GetDisplayNode().ApplyThresholdOn()

Expand All @@ -687,8 +754,7 @@ def preprocess_CT(self, win_level=40.0, win_width=400.0):
return image_data_pre

def preprocess_MR(self, lower_percent=0.5, upper_percent=99.5):
if self.image_data is None:
self.captureImage()
self.captureImage()

lower_bound, upper_bound = np.percentile(self.image_data[self.image_data > 0], lower_percent), np.percentile(self.image_data[self.image_data > 0], upper_percent)
image_data_pre = np.clip(self.image_data, lower_bound, upper_bound)
Expand Down
21 changes: 14 additions & 7 deletions MedSAM/MedSAMLite/Resources/UI/MedSAMLite.ui
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@
</rect>
</property>
<layout class="QVBoxLayout" name="verticalLayout">
<item>
<widget class="QPushButton" name="pbUpgrade">
<property name="text">
<string>Upgrade Module</string>
</property>
</widget>
</item>
<item>
<widget class="ctkCollapsibleButton" name="clbtnPreprocess">
<property name="text">
<string>Prepare Data</string>
</property>
<layout class="QGridLayout" name="gridLayout_2">
<item row="0" column="0">
<widget class="QLabel" name="lblPrep">
<property name="text">
<string>Preprocessing Options:</string>
</property>
</widget>
</item>
<item row="0" column="1">
<widget class="QPushButton" name="pbCTprep">
<property name="toolTip">
Expand Down Expand Up @@ -56,6 +56,13 @@
</property>
</widget>
</item>
<item row="0" column="0">
<widget class="QLabel" name="lblPrep">
<property name="text">
<string>Preprocessing Options:</string>
</property>
</widget>
</item>
</layout>
</widget>
</item>
Expand Down
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ You can watch a video tutorial of installation steps [here](https://youtu.be/qjs
6. `Choose a folder` to install module dependencies and click on `Install dependencies`. It can take several minutes.
7. Restart 3D Slicer.

**To update to a newer version:** Remove all pre-existing files from both step#2 and step#6 and install the new version as instructed before.
## Upgrade

**If you have version <= v0.02 installed:** Remove all pre-existing files from both step#2 and step#6 and install the new version as instructed before.

**If you have version > v0.02 installed:** Use the *Upgrade Module* button at the top of the module interface to check for and install new updates.

## Usage

Expand Down
4 changes: 3 additions & 1 deletion server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ def get_image(wmin: int, wmax: int):
# ), f"Accept either 1 channel gray image or 3 channel rgb. Got image shape {arr.shape} "
image = arr
# H, W = arr.shape[1:] # TODO: make sure h, w not filpped #################### This line is causing problem
W, H = arr.shape[1:]
H, W = arr.shape[1:]
print('Line 225, H, W:', H, W, file=sys.stderr)

for slice_idx in range(image.shape[0]):
# for slice_idx in tqdm(range(4)):
Expand Down Expand Up @@ -266,6 +267,7 @@ def get_bbox1024(mask_1024, bbox_shift=3):
y_min, y_max = np.min(y_indices), np.max(y_indices)
# add perturbation to bounding box coordinates
H, W = mask_1024.shape
print('Line 269, H, W:', H, W, file=sys.stderr)
x_min = max(0, x_min - bbox_shift)
x_max = min(W, x_max + bbox_shift)
y_min = max(0, y_min - bbox_shift)
Expand Down
Binary file modified server_essentials.zip
Binary file not shown.

0 comments on commit e7048eb

Please sign in to comment.