Skip to content

Commit

Permalink
Add show_prompts and show_images functions
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs committed Sep 16, 2024
1 parent e4b815f commit 7e5c754
Showing 1 changed file with 102 additions and 2 deletions.
104 changes: 102 additions & 2 deletions samgeo/samgeo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,7 @@ def save_video_segments_blended(
self,
output_dir: str,
img_ext: str = "png",
alpha: float = 0.6,
dpi: int = 200,
frame_stride: int = 1,
output_video: Optional[str] = None,
Expand All @@ -1236,6 +1237,8 @@ def save_video_segments_blended(
Args:
output_dir (str): The directory to save the output images.
img_ext (str): The file extension for the output images. Defaults to "png".
alpha (float): The alpha value for the blended masks. Defaults to 0.6.
dpi (int): The DPI (dots per inch) for the output images. Defaults to 200.
frame_stride (int): The stride for selecting frames to save. Defaults to 1.
output_video (Optional[str]): The path to the output video file. Defaults to None.
Expand All @@ -1246,11 +1249,11 @@ def save_video_segments_blended(

def show_mask(mask, ax, obj_id=None, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
color = np.concatenate([np.random.random(3), np.array([alpha])], axis=0)
else:
cmap = plt.get_cmap("tab10")
cmap_idx = 0 if obj_id is None else obj_id
color = np.array([*cmap(cmap_idx)[:3], 0.6])
color = np.array([*cmap(cmap_idx)[:3], alpha])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
Expand Down Expand Up @@ -1306,3 +1309,100 @@ def show_mask(mask, ax, obj_id=None, random_color=False):

if output_video is not None:
common.images_to_video(output_dir, output_video, fps=fps)

def show_images(self, path: str = None) -> None:
"""Show the images in the video.
Args:
path (str, optional): The path to the images. Defaults to None.
"""
if path is None:
path = self.video_path

if path is not None:
common.show_image_gui(path)

def show_prompts(
self,
prompts: Dict[int, Any],
frame_idx: int = 0,
mask: Any = None,
random_color: bool = False,
figsize: Tuple[int, int] = (9, 6),
) -> None:
"""Show the prompts on the image.
Args:
prompts (Dict[int, Any]): A dictionary containing the prompts with
points and labels.
frame_idx (int, optional): The frame index. Defaults to 0.
mask (Any, optional): The mask. Defaults to None.
random_color (bool, optional): Whether to use random colors for the
masks. Defaults to False.
figsize (Tuple[int, int], optional): The figure size. Defaults to (9, 6).
"""

from PIL import Image

def show_mask(mask, ax, obj_id=None, random_color=random_color):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
cmap = plt.get_cmap("tab10")
cmap_idx = 0 if obj_id is None else obj_id
color = np.array([*cmap(cmap_idx)[:3], 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=200):
pos_points = coords[labels == 1]
neg_points = coords[labels == 0]
ax.scatter(
pos_points[:, 0],
pos_points[:, 1],
color="green",
marker="*",
s=marker_size,
edgecolor="white",
linewidth=1.25,
)
ax.scatter(
neg_points[:, 0],
neg_points[:, 1],
color="red",
marker="*",
s=marker_size,
edgecolor="white",
linewidth=1.25,
)

def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(
plt.Rectangle(
(x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2
)
)

prompts = self._convert_prompts(prompts)
video_dir = self.video_path
frame_names = self._frame_names
plt.figure(figsize=figsize)
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))

for obj_id, prompt in prompts.items():
points = prompt.get("points", None)
labels = prompt.get("labels", None)
box = prompt.get("box", None)
anno_frame_idx = prompt.get("frame_idx", None)
if anno_frame_idx == frame_idx:
if points is not None:
show_points(points, labels, plt.gca())
if box is not None:
show_box(box, plt.gca())
if mask is not None:
show_mask(mask, plt.gca(), obj_id=obj_id)

0 comments on commit 7e5c754

Please sign in to comment.