Skip to content

Commit

Permalink
Merge pull request #6 from jhj0517/feature/invert-mask
Browse files Browse the repository at this point in the history
Add "Invert masks" option
  • Loading branch information
jhj0517 authored Aug 28, 2024
2 parents 998201e + 1b5d47b commit c9c163b
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 8 deletions.
9 changes: 5 additions & 4 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def launch(self):
nb_pixel_size = gr.Number(label="Pixel Size", interactive=True, minimum=1,
visible=self.default_filter == PIXELIZE_FILTER,
value=self.default_pixel_size)
cb_invert_mask = gr.Checkbox(label="invert mask", value=_mask_hparams["invert_mask"])
btn_generate_preview = gr.Button("GENERATE PREVIEW")

with gr.Row():
Expand All @@ -157,7 +158,7 @@ def launch(self):
nb_pixel_size])

preview_params = [vid_frame_prompter, dd_filter_mode, sld_frame_selector, nb_pixel_size,
cp_color_picker]
cp_color_picker, cb_invert_mask]
btn_generate_preview.click(fn=self.sam_inf.add_filter_to_preview,
inputs=preview_params,
outputs=[img_preview])
Expand All @@ -180,6 +181,7 @@ def launch(self):
choices=self.image_modes)
dd_models = gr.Dropdown(label="Model", value=DEFAULT_MODEL_TYPE,
choices=self.sam_inf.available_models)
cb_invert_mask = gr.Checkbox(label="invert mask", value=_mask_hparams["invert_mask"])

with gr.Accordion("Mask Parameters", open=False, visible=self.default_mode == AUTOMATIC_MODE) as acc_mask_hparams:
mask_hparams_component = self.mask_generation_parameters(_mask_hparams)
Expand All @@ -194,10 +196,9 @@ def launch(self):
output_file = gr.File(label="Generated psd file", scale=9)
btn_open_folder = gr.Button("📁\nOpen PSD folder", scale=1)

sources = [img_input, img_input_prompter, dd_input_modes]
model_params = [dd_models]
input_params = [img_input, img_input_prompter, dd_input_modes, dd_models, cb_invert_mask]
mask_hparams = mask_hparams_component + [cb_multimask_output]
input_params = sources + model_params + mask_hparams
input_params += mask_hparams

btn_generate.click(fn=self.sam_inf.divide_layer,
inputs=input_params, outputs=[gallery_output, output_file])
Expand Down
1 change: 1 addition & 0 deletions configs/default_hparams.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ mask_hparams:
min_mask_region_area: 25.0
use_m2m: true
multimask_output: true
invert_mask: false
7 changes: 6 additions & 1 deletion modules/mask_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ def decode_to_mask(seg: np.ndarray[np.bool_] | np.ndarray[np.uint8]) -> np.ndarr
return seg.astype(np.uint8)


def invert_masks(masks: List[Dict]) -> List[Dict]:
"""Invert the masks. Used for background masking"""
inverted = 1 - masks
return inverted


def generate_random_color() -> Tuple[int, int, int]:
"""Generate random color in RGB format"""
h = np.random.randint(0, 360)
Expand Down Expand Up @@ -47,7 +53,6 @@ def create_mask_layers(
List of RGBA images
"""
layer_list = []

sorted_masks = sorted(masks, key=lambda x: x['area'], reverse=True)

for info in sorted_masks:
Expand Down
33 changes: 30 additions & 3 deletions modules/sam_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from modules.paths import (MODELS_DIR, TEMP_OUT_DIR, TEMP_DIR, MODEL_CONFIGS, OUTPUT_DIR)
from modules.constants import (BOX_PROMPT_MODE, AUTOMATIC_MODE, COLOR_FILTER, PIXELIZE_FILTER, IMAGE_FILE_EXT)
from modules.mask_utils import (
invert_masks,
save_psd_with_masks,
create_mask_combined_images,
create_mask_gallery,
Expand Down Expand Up @@ -129,13 +130,15 @@ def init_video_inference_state(self,
def generate_mask(self,
image: np.ndarray,
model_type: str,
invert_mask: bool = False,
**params) -> List[Dict[str, Any]]:
"""
Generate masks with Automatic segmentation. Default hyperparameters are in './configs/default_hparams.yaml.'
Args:
image (np.ndarray): The input image.
model_type (str): The model type to load.
invert_mask (bool): Invert the mask output - used for background masking.
**params: The hyperparameters for the mask generator.
Returns:
Expand All @@ -154,6 +157,11 @@ def generate_mask(self,
except Exception as e:
logger.exception(f"Error while auto generating masks : {e}")
raise RuntimeError(f"Failed to generate masks") from e

if invert_mask:
generated_masks = [{'segmentation': invert_masks(mask['segmentation']),
'area': mask['area']} for mask in generated_masks]

return generated_masks

def predict_image(self,
Expand All @@ -162,6 +170,7 @@ def predict_image(self,
box: Optional[np.ndarray] = None,
point_coords: Optional[np.ndarray] = None,
point_labels: Optional[np.ndarray] = None,
invert_mask: bool = False,
**params) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Predict image with prompt data.
Expand All @@ -172,6 +181,7 @@ def predict_image(self,
box (np.ndarray): The box prompt data.
point_coords (np.ndarray): The point coordinates prompt data.
point_labels (np.ndarray): The point labels prompt data.
invert_mask (bool): Invert the mask output - used for background masking.
**params: The hyperparameters for the mask generator.
Returns:
Expand All @@ -195,6 +205,10 @@ def predict_image(self,
except Exception as e:
logger.exception(f"Error while predicting image with prompt: {str(e)}")
raise RuntimeError(f"Failed to predict image with prompt") from e

if invert_mask:
masks = invert_masks(masks)

return masks, scores, logits

def add_prediction_to_frame(self,
Expand Down Expand Up @@ -291,6 +305,7 @@ def add_filter_to_preview(self,
frame_idx: int,
pixel_size: Optional[int] = None,
color_hex: Optional[str] = None,
invert_mask: bool = False
):
"""
Add filter to the preview image with the prompt data. Specially made for gradio app.
Expand All @@ -302,6 +317,7 @@ def add_filter_to_preview(self,
frame_idx (int): The frame index of the video.
pixel_size (int): The pixel size for the pixelize filter.
color_hex (str): The color hex code for the solid color filter.
invert_mask (bool): Invert the mask output - used for background masking.
Returns:
np.ndarray: The filtered image output.
Expand Down Expand Up @@ -332,6 +348,9 @@ def add_filter_to_preview(self,
box=box
)
masks = (logits[0] > 0.0).cpu().numpy()
if invert_mask:
masks = invert_masks(masks)

generated_masks = self.format_to_auto_result(masks)

if filter_mode == COLOR_FILTER:
Expand All @@ -347,7 +366,8 @@ def create_filtered_video(self,
filter_mode: str,
frame_idx: int,
pixel_size: Optional[int] = None,
color_hex: Optional[str] = None
color_hex: Optional[str] = None,
invert_mask: bool = False
):
"""
Create a whole filtered video with video_inference_state. Currently only one frame tracking is supported.
Expand All @@ -359,6 +379,7 @@ def create_filtered_video(self,
frame_idx (int): The frame index of the video.
pixel_size (int): The pixel size for the pixelize filter.
color_hex (str): The color hex code for the solid color filter.
invert_mask (bool): Invert the mask output - used for background masking.
Returns:
str: The output video path.
Expand Down Expand Up @@ -390,12 +411,14 @@ def create_filtered_video(self,
inference_state=self.video_inference_state,
points=point_coords,
labels=point_labels,
box=box
box=box,
)

video_segments = self.propagate_in_video(inference_state=self.video_inference_state)
for frame_index, info in video_segments.items():
orig_image, masks = info["image"], info["mask"]
if invert_mask:
masks = invert_masks(masks)
masks = self.format_to_auto_result(masks)

if filter_mode == COLOR_FILTER:
Expand Down Expand Up @@ -423,6 +446,7 @@ def divide_layer(self,
image_prompt_input_data: Dict,
input_mode: str,
model_type: str,
invert_mask: bool = False,
*params):
"""
Divide the layer with the given prompt data and save psd file.
Expand All @@ -432,6 +456,7 @@ def divide_layer(self,
image_prompt_input_data (Dict): The image prompt data.
input_mode (str): The input mode for the image prompt data. ["Automatic", "Box Prompt"]
model_type (str): The model type to load.
invert_mask (bool): Invert the mask output.
*params: The hyperparameters for the mask generator.
Returns:
Expand Down Expand Up @@ -463,6 +488,7 @@ def divide_layer(self,
generated_masks = self.generate_mask(
image=image,
model_type=model_type,
invert_mask=invert_mask,
**hparams
)

Expand All @@ -481,7 +507,8 @@ def divide_layer(self,
box=box,
point_coords=point_coords,
point_labels=point_labels,
multimask_output=hparams["multimask_output"]
multimask_output=hparams["multimask_output"],
invert_mask=invert_mask
)
generated_masks = self.format_to_auto_result(predicted_masks)

Expand Down

0 comments on commit c9c163b

Please sign in to comment.