diff --git a/README.md b/README.md index 5c72d6a..3335440 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,24 @@ # segment_anything_gui + + +## 使用方法 + 这是一个类似PS的抠图工具,支持cpu和英伟达gpu。推荐opencv-python版本4.5.5.64(4.x大概都能跑) 使用方法 +0.打开程序文件,修改配置 + +``` + input_dir = r'G:\xiaowu-pic\133_select_new' + output_dir = r'G:\xiaowu-pic\133_new_segment' + skip_dir = r'G:\xiaowu-pic\133_done' # 需要跳动的文件(可以放原文件,也可以放程序输出的文件) + crop_mode = False # 是否裁剪到最小范围 + # alpha_channel = True # alpha_channel是否保留透明通道 + save_background_img = True # 是否同步生成白背景图片 +``` + 1.将待抠图的图片放到input文件夹中,然后启动程序。 https://github.com/facebookresearch/segment-anything @@ -13,20 +28,29 @@ https://github.com/facebookresearch/segment-anything 在图像上左键单击选择前景点(绿色),右键单击选择背景点(红色)。 -按下a或d键切换到上一张或下一张图片。按下空格键清除所有选点和mask。按下q键删除最后一个选点。 +按下`a`或d``键切换到上一张或下一张图片。按下空格键清除所有选点和mask。按下q键删除最后一个选点。 -按下s键保存抠图结果(如果有生成过Mask的话)。 +按下`s`键保存抠图结果(如果有生成过Mask的话)。 + +按下 `+`或`-`放大或缩小窗口 + +按下`e`键隐藏提示信息 3.Mask选取模式: -按下w键使用模型进行预测,进入Mask选取模式。 +按下`w`键使用模型进行预测,进入Mask选取模式。 + +在Mask选取模式下,可以按下`a`和`d`键切换不同的Mask。 + +按下`s`键保存抠图结果。 -在Mask选取模式下,可以按下a和d键切换不同的Mask。 +按下`w`键返回选点模式,下次模型将会在此mask基础上进行预测 -按下s键保存抠图结果。 +按下 `+`或`-`放大或缩小窗口 + +按下`e`键隐藏提示信息 -按下w键返回选点模式,下次模型将会在此mask基础上进行预测 4.返回选点模式,迭代优化选点 @@ -37,6 +61,12 @@ https://github.com/facebookresearch/segment-anything 程序将在output文件夹中生成抠好的图片,新切割出来图片的文件名会自增。 +--- +## **也可以直接将seg5.py文件复制到segment_anything项目下使用** +## **You can also directly copy the seg5.py file to the segment_anything project for use** +--- + + Segment-Anything-GUI @@ -44,6 +74,16 @@ This is a Photoshop-like image segmentation and extraction tool, supporting both How to Use +Open the program file and modify the configuration +``` + input_dir = r'G:\xiaowu-pic\133_select_new' + output_dir = r'G:\xiaowu-pic\133_new_segment' + skip_dir = r'G:\xiaowu-pic\133_done' # Files that need to be jumped (you can put the original file or the file output by the program) + crop_mode = False # Whether to crop to the minimum range + # alpha_channel = True # Whether alpha_channel retains the transparent channel + save_background_img = True # Whether to generate white background images simultaneously +``` + Place the images to be segmented into the input folder, then run the program. Point selection mode (one by one, multiple points at once may have average results): @@ -52,23 +92,31 @@ Left-click on the image to select foreground points (green). Right-click to select background points (red). -Press the a or d key to switch to the previous or next image. +Press the `a` or `d` key to switch to the previous or next image. Press the spacebar to clear all selected points and masks. -Press the q key to delete the last selected point. +Press the `q` key to delete the last selected point. -Press the s key to save the segmentation result (if a mask has been generated). +Press the `s` key to save the segmentation result (if a mask has been generated). + +Press `+` or `-` to zoom in or out of the window + +Press the `e` key to hide the prompt message Mask selection mode: -Press the w key to use the model for prediction and enter the mask selection mode. +Press the `w` key to use the model for prediction and enter the mask selection mode. In the mask selection mode, you can press the a and d keys to switch between different masks. -Press the s key to save the segmentation result. +Press the `s` key to save the segmentation result. + +Press the `w` key to return to point selection mode. The model will predict based on this mask the next time. + +Press `+` or `-` to zoom in or out of the window -Press the w key to return to point selection mode. The model will predict based on this mask the next time. +Press the `e` key to hide the prompt message Return to point selection mode to iteratively optimize selected points: diff --git a/seg5.py b/seg5.py index 09524d6..b90d7aa 100644 --- a/seg5.py +++ b/seg5.py @@ -3,101 +3,149 @@ import numpy as np from segment_anything import sam_model_registry, SamPredictor -input_dir = 'input' -output_dir = 'output' -crop_mode=True#是否裁剪到最小范围 -#alpha_channel是否保留透明通道 +input_dir = r'G:\xiaowu-pic\133_select_new' +output_dir = r'G:\xiaowu-pic\133_new_segment' +skip_dir = r'G:\xiaowu-pic\133_done' +crop_mode = False # 是否裁剪到最小范围 +# alpha_channel = True # alpha_channel是否保留透明通道 +save_background_img = False # 是否同步生成白背景图片 + print('最好是每加一个点就按w键predict一次') os.makedirs(output_dir, exist_ok=True) -image_files = [f for f in os.listdir(input_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg','.JPG','.JPEG','.PNG'))] - -sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth") -_ = sam.to(device="cuda")#注释掉这一行,会用cpu运行,速度会慢很多 +image_files = [f for f in os.listdir(input_dir) if + f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG'))] +skip_files = [] +if os.path.exists(skip_dir) and os.path.isdir(skip_dir): + skip_files = [f for f in os.listdir(skip_dir) if + f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG'))] + skip_files = [f[:f.rfind('_')] for f in skip_files] + image_files = [f for f in image_files if f[:f.rfind('.')] not in skip_files] + +sam = sam_model_registry["vit_h"](checkpoint="./checkpoint/sam_vit_h_4b8939.pth") +_ = sam.to(device="cuda") # 注释掉这一行,会用cpu运行,速度会慢很多 predictor = SamPredictor(sam) + + def mouse_click(event, x, y, flags, param): global input_point, input_label, input_stop + + x = round(x / zoom_rate) + y = round(y / zoom_rate) if not input_stop: - if event == cv2.EVENT_LBUTTONDOWN : + if event == cv2.EVENT_LBUTTONDOWN: input_point.append([x, y]) input_label.append(1) - elif event == cv2.EVENT_RBUTTONDOWN : + elif event == cv2.EVENT_RBUTTONDOWN: input_point.append([x, y]) input_label.append(0) else: - if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN : + if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN: print('此时不能添加点,按w退出mask选择模式') def apply_mask(image, mask, alpha_channel=True): - if alpha_channel: - alpha = np.zeros_like(image[..., 0]) - alpha[mask == 1] = 255 - image = cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha)) - else: - image = np.where(mask[..., None] == 1, image, 0) - return image + image = np.where(mask[..., None] == 1, image, 0) + images = [image] + if save_background_img: + image_white = np.where(mask[..., None] == 1, image, 255) + images.append(image_white) + + for i, im in enumerate(images): + if alpha_channel and im.shape[-1] == 3: + alpha = np.zeros_like(im[..., 0]) + alpha[mask == 1] = 255 + images[i] = cv2.merge((im[..., 0], im[..., 1], im[..., 2], alpha)) + return images + -def apply_color_mask(image, mask, color, color_dark = 0.5): +def apply_color_mask(image, mask, color, color_dark=0.5): for c in range(3): image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], image[:, :, c]) return image + def get_next_filename(base_path, filename): name, ext = os.path.splitext(filename) + for i in range(1, 101): new_name = f"{name}_{i}{ext}" - if not os.path.exists(os.path.join(base_path, new_name)): - return new_name + names = [new_name] + if save_background_img: + new_name_white = f"{name}_{i}w{ext}" + names.append(new_name_white) + if not os.path.exists(os.path.join(base_path, names[0])) and (len(names) == 1 or not os.path.exists( + os.path.join(base_path, names[1]))): + return names return None + def save_masked_image(image, mask, output_dir, filename, crop_mode_): if crop_mode_: y, x = np.where(mask) y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max() - cropped_mask = mask[y_min:y_max+1, x_min:x_max+1] - cropped_image = image[y_min:y_max+1, x_min:x_max+1] - masked_image = apply_mask(cropped_image, cropped_mask) + cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1] + cropped_image = image[y_min:y_max + 1, x_min:x_max + 1] + masked_images = apply_mask(cropped_image, cropped_mask) else: - masked_image = apply_mask(image, mask) - filename = filename[:filename.rfind('.')]+'.png' - new_filename = get_next_filename(output_dir, filename) - - if new_filename: - if masked_image.shape[-1] == 4: - cv2.imwrite(os.path.join(output_dir, new_filename), masked_image, [cv2.IMWRITE_PNG_COMPRESSION, 9]) - else: - cv2.imwrite(os.path.join(output_dir, new_filename), masked_image) - print(f"Saved as {new_filename}") + masked_images = apply_mask(image, mask) + + filename = filename[:filename.rfind('.')] + '.png' + new_filenames = get_next_filename(output_dir, filename) + + if new_filenames and masked_images: + for new_filename, masked_image in zip(new_filenames, masked_images): + if masked_image.shape[-1] == 4: + cv2.imwrite(os.path.join(output_dir, new_filename), masked_image, [cv2.IMWRITE_PNG_COMPRESSION, 9]) + else: + cv2.imwrite(os.path.join(output_dir, new_filename), masked_image) + print(f"Saved as {new_filename}") else: print("Could not save the image. Too many variations exist.") + +def show_infos(info, show, image_to_show): + if not show: + return + + for i, inf in enumerate(info): + cv2.putText(image_to_show, inf, (10, 30 + 40 * i), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA) + + current_index = 0 cv2.namedWindow("image") cv2.setMouseCallback("image", mouse_click) input_point = [] input_label = [] -input_stop=False +input_stop = False +zoom_rate = 1 +show_info = True while True: filename = image_files[current_index] - image_orign = cv2.imread(os.path.join(input_dir, filename)) + image_orign = cv2.imread(os.path.join(input_dir, filename), cv2.IMREAD_UNCHANGED) image_crop = image_orign.copy() image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB) selected_mask = None - logit_input= None + logit_input = None while True: - #print(input_point) - input_stop=False + # print(input_point) + input_stop = False image_display = image_orign.copy() - display_info = f'{filename} | Press s to save | Press w to predict | Press d to next image | Press a to previous image | Press space to clear | Press q to remove last point ' - cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA) + + display_info = f'{filename} | Press s to save | Press w to predict' + display_info1 = f'| Press a to previous image | Press d to next image | Press space to clear' + display_info2 = f'| Press q to remove last point | Press +,- to zoom in or out | Press e to hide suggestion info' + infos = [display_info, display_info1, display_info2] + show_infos(infos, show_info, image_display) + for point, label in zip(input_point, input_label): color = (0, 255, 0) if label == 1 else (0, 0, 255) cv2.circle(image_display, tuple(point), 5, color, -1) - if selected_mask is not None : + if selected_mask is not None: color = tuple(np.random.randint(0, 256, 3).tolist()) - selected_image = apply_color_mask(image_display,selected_mask, color) + selected_image = apply_color_mask(image_display, selected_mask, color) + image_display = cv2.resize(image_display, None, fx=zoom_rate, fy=zoom_rate, interpolation=cv2.INTER_LINEAR) cv2.imshow("image", image_display) key = cv2.waitKey(1) @@ -105,60 +153,75 @@ def save_masked_image(image, mask, output_dir, filename, crop_mode_): input_point = [] input_label = [] selected_mask = None - logit_input= None + logit_input = None elif key == ord("w"): - input_stop=True + input_stop = True if len(input_point) > 0 and len(input_label) > 0: - + predictor.set_image(image) input_point_np = np.array(input_point) input_label_np = np.array(input_label) - masks, scores, logits= predictor.predict( + masks, scores, logits = predictor.predict( point_coords=input_point_np, point_labels=input_label_np, mask_input=logit_input[None, :, :] if logit_input is not None else None, multimask_output=True, ) - mask_idx=0 + mask_idx = 0 num_masks = len(masks) - while(1): + while (1): color = tuple(np.random.randint(0, 256, 3).tolist()) image_select = image_orign.copy() - selected_mask=masks[mask_idx] - selected_image = apply_color_mask(image_select,selected_mask, color) - mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | Press w to confirm | Press d to next mask | Press a to previous mask | Press q to remove last point | Press s to save' - cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA) - + selected_mask = masks[mask_idx] + selected_image = apply_color_mask(image_select, selected_mask, color) + + mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | Press w ' \ + f'to confirm' + mask_info1 = f'| Press d to next mask | Press a to previous mask | Press q to remove last point ' + mask_info2 = f'| Press s to save | Press +,- to zoom in or out | Press e to hide suggestion info' + infos = [mask_info, mask_info1, mask_info2] + show_infos(infos, show_info, selected_image) + + selected_image = cv2.resize(selected_image, None, fx=zoom_rate, fy=zoom_rate, + interpolation=cv2.INTER_LINEAR) cv2.imshow("image", selected_image) - key=cv2.waitKey(10) - if key == ord('q') and len(input_point)>0: + key = cv2.waitKey(10) + if key == ord('q') and len(input_point) > 0: input_point.pop(-1) input_label.pop(-1) elif key == ord('s'): save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode) - elif key == ord('a') : - if mask_idx>0: - mask_idx-=1 + elif key == ord('a'): + if mask_idx > 0: + mask_idx -= 1 else: - mask_idx=num_masks-1 - elif key == ord('d') : - if mask_idx 0.05: + zoom_rate -= 0.05 + elif key == ord('e'): + show_info = not show_info + + logit_input = logits[mask_idx, :, :] + print('max score:', np.argmax(scores), ' select:', mask_idx) elif key == ord('a'): current_index = max(0, current_index - 1) @@ -172,11 +235,18 @@ def save_masked_image(image, mask, output_dir, filename, crop_mode_): break elif key == 27: break - elif key == ord('q') and len(input_point)>0: + elif key == ord('q') and len(input_point) > 0: input_point.pop(-1) input_label.pop(-1) - elif key == ord('s') and selected_mask is not None : + elif key == ord('s') and selected_mask is not None: save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode) + elif key == ord('+'): + zoom_rate += 0.05 + elif key == ord('-'): + if zoom_rate > 0.05: + zoom_rate -= 0.05 + elif key == ord('e'): + show_info = not show_info if key == 27: break