Skip to content

Commit

Permalink
Merge pull request #10 from jhj0517/feature/update-ci
Browse files Browse the repository at this point in the history
Add CI / CD actions
  • Loading branch information
jhj0517 authored Oct 8, 2024
2 parents ff62198 + d6d2f20 commit 6ba1785
Show file tree
Hide file tree
Showing 7 changed files with 275 additions and 11 deletions.
5 changes: 5 additions & 0 deletions .github/pull_request_template.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
## Related issues / PRs
- #

## Summarize Changes
1.
38 changes: 38 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Continuous-integration workflow: sets up Python, installs the dependencies and runs pytest.
name: CI

on:
  # Allows the workflow to be started manually.
  workflow_dispatch:

  push:
    branches:
      - master
  pull_request:
    branches:
      - master

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python: ["3.10"]

    steps:
      # Deletes the runner's hosted tool cache directory to free disk space (per the step name).
      - name: Clean up space for action
        run: rm -rf /opt/hostedtoolcache

      - uses: actions/checkout@v4
      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python }}

      # ffmpeg is invoked by the test helpers (trim_video) to shorten the sample video.
      - name: Install git and ffmpeg
        run: sudo apt-get update && sudo apt-get install -y git ffmpeg

      - name: Install dependencies
        run: pip install -r requirements.txt pytest

      # -rs adds the reasons for skipped tests to pytest's summary output.
      - name: Run test
        run: python -m pytest -rs tests
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
outputs/
models/
venv/
.pytest_cache/
__pycache__/
*.png
!docs/example_image_segmentation.png
!docs/example_psd_file.png
Expand Down
27 changes: 16 additions & 11 deletions modules/sam_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def add_prediction_to_frame(self,
return out_frame_idx, out_obj_ids, out_mask_logits

def propagate_in_video(self,
inference_state: Optional[Dict] = None,):
inference_state: Optional[Dict] = None,) -> Dict:
"""
Propagate in the video with the tracked predictions for each frame. Currently only supports
single frame tracking.
Expand Down Expand Up @@ -324,13 +324,18 @@ def add_filter_to_preview(self,
logger.exception("Error while adding filter to preview, load video predictor first")
raise f"Error while adding filter to preview"

if not image_prompt_input_data["points"]:
image, prompt = image_prompt_input_data["image"], image_prompt_input_data["points"]
if not prompt:
error_message = ("No prompt data provided. If this is an incorrect flag, "
"Please press the eraser button (on the image prompter) and add your prompts again.")
logger.error(error_message)
raise gr.Error(error_message, duration=20)

image, prompt = image_prompt_input_data["image"], image_prompt_input_data["points"]
if not image:
error_message = "No image data provided."
logger.error(error_message)
raise gr.Error(error_message, duration=20)

image = np.array(image.convert("RGB"))

point_labels, point_coords, box = self.handle_prompt_data(prompt)
Expand Down Expand Up @@ -372,7 +377,7 @@ def create_filtered_video(self,
This needs FFmpeg to run. Returns two output path because of the gradio app.
Args:
image_prompt_input_data (Dict): The image prompt data.
image_prompt_input_data (Dict): The image prompt data with "image" and "points" keys.
filter_mode (str): The filter mode to apply. ["Solid Color", "Pixelize"]
frame_idx (int): The frame index of the video.
pixel_size (int): The pixel size for the pixelize filter.
Expand All @@ -388,21 +393,21 @@ def create_filtered_video(self,
logger.exception("Error while adding filter to preview, load video predictor first")
raise RuntimeError("Error while adding filter to preview")

if not image_prompt_input_data["points"]:
prompt = image_prompt_input_data["points"]
if not prompt:
error_message = ("No prompt data provided. If this is an incorrect flag, "
"Please press the eraser button (on the image prompter) and add your prompts again.")
logger.error(error_message)
raise gr.Error(error_message, duration=20)

point_labels, point_coords, box = self.handle_prompt_data(prompt)
obj_id = frame_idx

output_dir = os.path.join(self.output_dir, "filter")

clean_files_with_extension(TEMP_OUT_DIR, IMAGE_FILE_EXT)
self.video_predictor.reset_state(self.video_inference_state)

prompt_frame_image, prompt = image_prompt_input_data["image"], image_prompt_input_data["points"]

point_labels, point_coords, box = self.handle_prompt_data(prompt)
obj_id = frame_idx

idx, scores, logits = self.add_prediction_to_frame(
frame_idx=frame_idx,
obj_id=obj_id,
Expand Down Expand Up @@ -536,7 +541,7 @@ def handle_prompt_data(
Handle data from ImageInputPrompter.
Args:
prompt_data (Dict): A dictionary containing the 'prompt' key with a list of prompts.
prompt_data (Dict): A dictionary containing the 'points' key with a list of prompts.
Returns:
point_labels (List): list of points labels.
Expand Down
67 changes: 67 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os.path
import requests
import numpy as np
import subprocess
import torch

from modules.paths import *

# Model type handed to SamInference by the tests (the "tiny" Hiera variant, going by its name).
TEST_MODEL = "sam2.1_hiera_tiny"
# Remote sample assets; download_test_files() fetches them when the local copies are missing.
TEST_VIDEO_URL = "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerEscapes.mp4"
TEST_IMAGE_URL = "https://raw.githubusercontent.com/test-images/png/refs/heads/main/202105/cs-blue-00f.png"
# Local destinations of the downloaded assets, inside the project's "tests" directory.
TEST_VIDEO_PATH = os.path.join(WEBUI_DIR, "tests", "test_video.mp4")
TEST_IMAGE_PATH = os.path.join(WEBUI_DIR, "tests", "test_image.png")
# Two point prompts, one row per point — presumably (x, y) pixel coordinates; TODO confirm.
TEST_POINTS = np.array([[850., 424.],
                        [810., 510.]])
# One label per row of TEST_POINTS — presumably 1 marks a foreground point; TODO confirm.
TEST_LABELS = np.array([1., 1.])
# A single box prompt — presumably (x1, y1, x2, y2); TODO confirm.
TEST_BOX = np.array([[662., 197., 1056., 711.]])
# Prompt rows in the raw format passed as the "points" entry of the image prompt data.
# NOTE(review): looks like the six-value row layout emitted by the Gradio image prompter — confirm.
TEST_GRADIO_PROMPT_BOX = [[653.0, 163.0, 2.0, 1132.0, 702.0, 3.0]]
TEST_GRADIO_PROMPT_POINTS = [[817.0, 366.0, 1.0, 0.0, 0.0, 4.0], [864.0, 533.0, 1.0, 0.0, 0.0, 4.0]]


def download_test_sam_model(model_name: str):
    """
    Download the SAM checkpoint for ``model_name`` into ``MODELS_DIR`` if it is not there yet.

    Args:
        model_name (str): Model type; the checkpoint is looked up as ``<model_name>.pt``.
    """
    checkpoint_path = os.path.join(MODELS_DIR, model_name) + ".pt"
    if not os.path.exists(checkpoint_path):
        # The downloader module is only imported when a download is actually required.
        from modules.model_downloader import download_sam_model_url
        download_sam_model_url(model_type=model_name, model_dir=MODELS_DIR)


def download_test_files():
    """
    Ensure the sample image and the sample video used by the tests exist locally.

    Each asset is downloaded only when its local file is missing. A freshly downloaded
    video is then trimmed to its first second.
    """
    image_missing = not os.path.exists(TEST_IMAGE_PATH)
    if image_missing:
        download_file(TEST_IMAGE_URL, TEST_IMAGE_PATH)

    if os.path.exists(TEST_VIDEO_PATH):
        return

    # NOTE(review): trimming is assumed to belong to the fresh-download branch — confirm.
    download_file(TEST_VIDEO_URL, TEST_VIDEO_PATH)
    trim_video(TEST_VIDEO_PATH, seconds=1)


def trim_video(video_path, seconds=1):
    """
    Trim the video at ``video_path`` in place so that only its first ``seconds`` seconds remain.

    FFmpeg stream-copies (``-c copy``) into a temporary file next to the source, and that
    file then replaces the source. The operation is best-effort: a failure is printed rather
    than raised, the source video is left untouched and the temporary file is removed.

    Args:
        video_path (str): Path of the video to trim; overwritten on success.
        seconds (int | float): Duration to keep, passed to FFmpeg's ``-t`` option.
    """
    temp_output_path = video_path + ".temp.mp4"

    # "-y" overwrites a stale temp file left behind by an interrupted run instead of
    # asking for confirmation.
    command = [
        "ffmpeg", "-y", "-i", video_path, "-t", f"{seconds}", "-c", "copy", temp_output_path
    ]

    try:
        subprocess.run(command, check=True)
        os.replace(temp_output_path, video_path)
        print(f"Trimmed video to {seconds} seconds and saved to {video_path}")
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing ffmpeg executable as well as a failed replace.
        print(f"Error trimming video: {e}")
        # Do not leave a half-written temp file next to the source video.
        if os.path.exists(temp_output_path):
            os.remove(temp_output_path)


def download_file(url, path, timeout=30):
    """
    Download ``url`` to ``path``, streaming the response body in 8 KiB chunks.

    The body is first written to ``path + ".part"`` and moved to ``path`` only once it has
    been received completely, so an interrupted download never leaves a truncated file that
    a later ``os.path.exists(path)`` check would mistake for a finished one.

    Args:
        url (str): Address of the file to fetch.
        path (str): Destination file path; its parent directory is created when needed.
        timeout (float): Seconds to wait for the connection and for each read before
            ``requests`` raises, so an unresponsive server cannot hang the test run.

    A non-200 status is reported with a printed message and nothing is written.
    Network errors raised by ``requests`` propagate to the caller.
    """
    # The context manager releases the streamed connection on every exit path.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            print(f"Failed to download {url}. Status code: {response.status_code}")
            return

        parent_dir = os.path.dirname(path)
        if parent_dir:  # os.makedirs("") raises, so skip it for bare file names
            os.makedirs(parent_dir, exist_ok=True)

        partial_path = path + ".part"
        try:
            with open(partial_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
            os.replace(partial_path, path)
        finally:
            # After a successful replace the partial file is gone and this is a no-op.
            if os.path.exists(partial_path):
                os.remove(partial_path)

    print(f"Downloaded {url} to {path}")


def is_cuda_available():
    """Return what ``torch.cuda.is_available()`` reports; used to skip GPU-only tests."""
    cuda_ready = torch.cuda.is_available()
    return cuda_ready
57 changes: 57 additions & 0 deletions tests/test_image_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from PIL import Image
import pytest
from typing import Dict, Optional

from test_config import *
import numpy as np
from modules.paths import *
from modules.constants import *
from modules.sam_inference import SamInference


@pytest.mark.skipif(
    not is_cuda_available(),
    reason="Skipping because this test only works in GPU"
)
@pytest.mark.parametrize(
    "model_name,image_path,points,labels,box,multimask_output",
    [
        (TEST_MODEL, TEST_IMAGE_PATH, TEST_POINTS, TEST_LABELS, TEST_BOX, True)
    ]
)
def test_image_segmentation(
    model_name: str,
    image_path: str,
    points: np.ndarray,
    labels: np.ndarray,
    box: np.ndarray,
    multimask_output: bool
):
    """
    Run point-prompted image prediction and check that all three outputs are NumPy arrays.

    ``box`` is supplied by the parametrization but is not forwarded to ``predict_image``.
    Skipped when CUDA is unavailable.
    """
    download_test_files()

    sam = SamInference()
    print("Device:", sam.device)

    prediction = sam.predict_image(
        image=load_image(image_path),
        model_type=model_name,
        point_coords=points,
        point_labels=labels,
        multimask_output=multimask_output,
    )
    masks, scores, logits = prediction

    for output in (masks, scores, logits):
        assert isinstance(output, np.ndarray)


def load_image(image_path):
    """Open the image at ``image_path``, convert it to RGB and return it as a NumPy array."""
    return np.array(Image.open(image_path).convert('RGB'))
90 changes: 90 additions & 0 deletions tests/test_video_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import pytest
from typing import Dict

from test_config import *
import numpy as np
from modules.paths import *
from modules.constants import *
from modules.sam_inference import SamInference


@pytest.mark.skipif(
    not is_cuda_available(),
    reason="Skipping because this test only works in GPU"
)
@pytest.mark.parametrize(
    "model_name,video_path,points,labels,box",
    [
        (TEST_MODEL, TEST_VIDEO_PATH, TEST_POINTS, TEST_LABELS, TEST_BOX)
    ]
)
def test_video_segmentation(
    model_name: str,
    video_path: str,
    points: np.ndarray,
    labels: np.ndarray,
    box: np.ndarray
):
    """
    Add a point prompt on frame 0 and a box prompt on frame 1, then propagate through the video.

    Passes when ``propagate_in_video`` returns a non-empty dict. Skipped when CUDA is
    unavailable.
    """
    download_test_files()

    inferencer = SamInference()
    print("Device:", inferencer.device)
    inferencer.init_video_inference_state(
        vid_input=video_path,
        model_type=model_name,
    )

    # Use the parametrized prompts (not the module constants) so that additional
    # parameter sets added to the decorator are actually exercised.
    inferencer.add_prediction_to_frame(
        frame_idx=0,
        obj_id=0,
        points=points,
        labels=labels,
    )

    inferencer.add_prediction_to_frame(
        frame_idx=1,
        obj_id=1,
        box=box,
    )

    video_segments = inferencer.propagate_in_video()

    assert video_segments and isinstance(video_segments, dict)


@pytest.mark.skipif(
    not is_cuda_available(),
    reason="Skipping because this test only works in GPU"
)
@pytest.mark.parametrize(
    "model_name,video_path,gradio_prompt",
    [
        (TEST_MODEL, TEST_VIDEO_PATH, TEST_GRADIO_PROMPT_BOX)
    ]
)
def test_filtered_video_creation_pipeline(
    model_name: str,
    video_path: str,
    gradio_prompt: list,
):
    """
    Create a colour-filtered video from a raw Gradio prompt and check the output files exist.

    ``gradio_prompt`` is a list of prompt rows and is passed as the ``"points"`` entry of the
    image prompt data. Skipped when CUDA is unavailable.
    """
    download_test_files()

    inferencer = SamInference()
    print("Device:", inferencer.device)
    inferencer.init_video_inference_state(
        vid_input=video_path,
        model_type=model_name,
    )
    prompt_data = {
        "points": gradio_prompt
    }

    # Bind the two returned paths to distinct names so both are verified; reusing one
    # name for both targets would silently discard the first.
    first_out_path, second_out_path = inferencer.create_filtered_video(
        image_prompt_input_data=prompt_data,
        filter_mode=COLOR_FILTER,
        frame_idx=0,
        color_hex=DEFAULT_COLOR,
        invert_mask=True
    )

    assert os.path.exists(first_out_path)
    assert os.path.exists(second_out_path)

0 comments on commit 6ba1785

Please sign in to comment.