diff --git a/MedSAM/MedSAMLite/MedSAMLite.py b/MedSAM/MedSAMLite/MedSAMLite.py index 873ab24..40fb513 100644 --- a/MedSAM/MedSAMLite/MedSAMLite.py +++ b/MedSAM/MedSAMLite/MedSAMLite.py @@ -35,6 +35,8 @@ except: pass # no installation anymore, shorter plugin load +MEDSAMLITE_VERSION = 'v0.03' + # # MedSAMLite # @@ -181,18 +183,22 @@ def setup(self) -> None: DEPENDENCIES_AVAILABLE = False if not DEPENDENCIES_AVAILABLE: - from PythonQt.QtGui import QLabel, QPushButton, QSpacerItem, QSizePolicy + from PythonQt.QtGui import QLabel, QPushButton, QSpacerItem, QSizePolicy, QCheckBox import ctk path_instruction = QLabel('Choose a folder to install module dependencies in') restart_instruction = QLabel('Restart 3D Slicer after all dependencies are installed!') ctk_install_path = ctk.ctkPathLineEdit() ctk_install_path.filters = ctk.ctkPathLineEdit.Dirs + + local_install = QCheckBox("Install from local server_essentials.zip") + local_install.toggled.connect(lambda:self.toggleLocalInstall(local_install, ctk_install_path)) install_btn = QPushButton('Install dependencies') install_btn.clicked.connect(lambda: self.logic.install_dependencies(ctk_install_path)) self.layout.addWidget(path_instruction) + self.layout.addWidget(local_install) self.layout.addWidget(ctk_install_path) self.layout.addWidget(install_btn) self.layout.addWidget(restart_instruction) @@ -233,6 +239,7 @@ def setup(self) -> None: self.addObserver(slicer.mrmlScene, slicer.mrmlScene.EndCloseEvent, self.onSceneEndClose) # Buttons + self.ui.pbUpgrade.connect('clicked(bool)', lambda: self.logic.run_on_background(self.logic.upgrade, (True,), 'Checking for updates...')) self.ui.pbSendImage.connect('clicked(bool)', lambda: self.logic.sendImage()) self.ui.pbSegment.connect('clicked(bool)', lambda: self.logic.applySegmentation()) @@ -289,6 +296,15 @@ def makeROI2D(self): roiNode.SetSize(roi_size[0], roi_size[1], 1) roi_center = np.array(roiNode.GetCenter()) roiNode.SetCenter([roi_center[0], roi_center[1], slicer.app.layoutManager().sliceWidget("Red").sliceLogic().GetSliceOffset()]) + + def toggleLocalInstall(self, checkbox, file_selector): + import ctk + file_selector.currentPath = '' + if checkbox.isChecked(): + file_selector.filters = ctk.ctkPathLineEdit.Files + file_selector.nameFilters = ['server_essentials.zip'] + else: + file_selector.filters = ctk.ctkPathLineEdit.Dirs def cleanup(self) -> None: @@ -454,18 +470,19 @@ def pip_install_wrapper(self, command, event): slicer.util.pip_install(command) event.set() - def download_wrapper(self, url, filename, event): - with urlopen(url) as r: - # self.setTotalProgress.emit(int(r.info()["Content-Length"])) - with open(filename, "wb") as f: - while True: - chunk = r.read(1024) - if chunk is None: - continue - elif chunk == b"": - break - f.write(chunk) - + def download_wrapper(self, url, filename, download_needed, event): + if download_needed: + with urlopen(url) as r: + # self.setTotalProgress.emit(int(r.info()["Content-Length"])) + with open(filename, "wb") as f: + while True: + chunk = r.read(1024) + if chunk is None: + continue + elif chunk == b"": + break + f.write(chunk) + with zipfile.ZipFile(filename, 'r') as zip_ref: zip_ref.extractall(os.path.dirname(filename)) @@ -476,15 +493,26 @@ def install_dependencies(self, ctk_path): print('Installation path is empty') return - print('Installation will happen in %s'%ctk_path.currentPath) - self.widget.write_setting(ctk_path.currentPath) + if os.path.isfile(ctk_path.currentPath) and os.path.basename(ctk_path.currentPath) == 'server_essentials.zip': + install_path = os.path.abspath(os.path.dirname(ctk_path.currentPath)) + download_needed = False + elif os.path.isdir(ctk_path.currentPath): + install_path = ctk_path.currentPath + download_needed = True + else: + print('Invalid installation path') + return + + print('Installation will happen in %s'%install_path) + self.widget.write_setting(install_path) + file_url = 'https://github.com/rasakereh/medsam-3dslicer/raw/master/server_essentials.zip' - filename = os.path.join(ctk_path.currentPath, 'server_essentials.zip') + filename = os.path.join(install_path, 'server_essentials.zip') - self.run_on_background(self.download_wrapper, (file_url, filename), 'Downloading additional files...') + self.run_on_background(self.download_wrapper, (file_url, filename, download_needed), 'Downloading additional files...') - self.server_dir = os.path.join(ctk_path.currentPath + '/', 'server_essentials') + self.server_dir = os.path.join(install_path + '/', 'server_essentials') dependencies = { 'PyTorch': 'torch==2.0.1 torchvision==0.15.2', @@ -498,6 +526,41 @@ def install_dependencies(self, ctk_path): self.run_on_background(self.pip_install_wrapper, (dependencies[dependency],), 'Installing dependencies: %s'%dependency) + def upgrade(self, download, event): + try: + self.progressbar.setLabelText('Checking for updates...') + latest_version_req = requests.get('https://github.com/bowang-lab/MedSAMSlicer/releases/latest') + latest_version = latest_version_req.url.split('/')[-1] + latest_version = float(latest_version[1:]) + curr_version = float(MEDSAMLITE_VERSION[1:]) + print('Latest version identified:', latest_version) + print('Current version is:', curr_version) + if latest_version > curr_version: + print('Upgrade available') + github_base = 'https://raw.githubusercontent.com/bowang-lab/MedSAMSlicer/v%.2f/'%latest_version + server_url = github_base + 'server/server.py' + module_url = github_base + 'MedSAM/MedSAMLite/MedSAMLite.py' + + self.progressbar.setLabelText('Downloading updates...') + server_req = requests.get(server_url) + module_req = requests.get(module_url) + + with open(os.path.join(self.server_dir, 'server.py'), 'w') as server_file: + server_file.write(server_req.text) + with open(__file__, 'w') as module_file: + module_file.write(module_req.text) + self.progressbar.setLabelText('Upgraded successfully, please restart Slicer.') + + else: + self.progressbar.setLabelText('Already using the latest version') + except: + self.progressbar.setLabelText('Error happened while upgrading') + + time.sleep(3) + + event.set() + + def run_on_background(self, target, args, title): self.progressbar = slicer.util.createProgressDialog(autoClose=False) self.progressbar.minimum = 0 @@ -597,6 +660,11 @@ def sendImage(self, serverUrl='http://127.0.0.1:5555', numpyServerAddress=("127. def inferSegmentation(self, serverUrl='http://127.0.0.1:5555'): print('sending infer request...') + ################ DEBUG MODE ################ + if self.volume_node is None: + self.captureImage() + ################ DEBUG MODE ################ + slice_idx, bbox, zrange = self.get_bounding_box() response = requests.post(f'{serverUrl}/infer', json={"slice_idx": slice_idx, "bbox": bbox, "zrange": zrange}) @@ -671,8 +739,7 @@ def get_bounding_box(self): return slice_idx, bbox, zrange def preprocess_CT(self, win_level=40.0, win_width=400.0): - if self.image_data is None: - self.captureImage() + self.captureImage() # self.volume_node.GetDisplayNode().SetThreshold(0, 255) # self.volume_node.GetDisplayNode().ApplyThresholdOn() @@ -687,8 +754,7 @@ def preprocess_CT(self, win_level=40.0, win_width=400.0): return image_data_pre def preprocess_MR(self, lower_percent=0.5, upper_percent=99.5): - if self.image_data is None: - self.captureImage() + self.captureImage() lower_bound, upper_bound = np.percentile(self.image_data[self.image_data > 0], lower_percent), np.percentile(self.image_data[self.image_data > 0], upper_percent) image_data_pre = np.clip(self.image_data, lower_bound, upper_bound) diff --git a/MedSAM/MedSAMLite/Resources/UI/MedSAMLite.ui b/MedSAM/MedSAMLite/Resources/UI/MedSAMLite.ui index a5b991d..631acac 100644 --- a/MedSAM/MedSAMLite/Resources/UI/MedSAMLite.ui +++ b/MedSAM/MedSAMLite/Resources/UI/MedSAMLite.ui @@ -11,19 +11,19 @@ + + + + Upgrade Module + + + Prepare Data - - - - Preprocessing Options: - - - @@ -56,6 +56,13 @@ + + + + Preprocessing Options: + + + diff --git a/README.md b/README.md index 40ad6f4..40cd33e 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,11 @@ You can watch a video tutorial of installation steps [here](https://youtu.be/qjs 6. `Choose a folder` to install module dependencies and click on `Install dependencies`. It can take several minutes. 7. Restart 3D Slicer. -**To update to a newer version:** Remove all pre-existing files from both step#2 and step#6 and install the new version as instructed before. +## Upgrade + +**If you have version <= v0.02 installed:** Remove all pre-existing files from both step#2 and step#6 and install the new version as instructed before. + +**If you have version > v0.02 installed:** Use the *Upgrade Module* button at the top of the module interface to check for and install new updates. ## Usage diff --git a/server/server.py b/server/server.py index 67e434a..c84bb22 100644 --- a/server/server.py +++ b/server/server.py @@ -222,7 +222,8 @@ def get_image(wmin: int, wmax: int): # ), f"Accept either 1 channel gray image or 3 channel rgb. Got image shape {arr.shape} " image = arr # H, W = arr.shape[1:] # TODO: make sure h, w not filpped #################### This line is causing problem - W, H = arr.shape[1:] + H, W = arr.shape[1:] + print('Line 225, H, W:', H, W, file=sys.stderr) for slice_idx in range(image.shape[0]): # for slice_idx in tqdm(range(4)): @@ -266,6 +267,7 @@ def get_bbox1024(mask_1024, bbox_shift=3): y_min, y_max = np.min(y_indices), np.max(y_indices) # add perturbation to bounding box coordinates H, W = mask_1024.shape + print('Line 269, H, W:', H, W, file=sys.stderr) x_min = max(0, x_min - bbox_shift) x_max = min(W, x_max + bbox_shift) y_min = max(0, y_min - bbox_shift) diff --git a/server_essentials.zip b/server_essentials.zip index a8698f4..bfdec3f 100644 Binary files a/server_essentials.zip and b/server_essentials.zip differ