-
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from jhj0517/feature/update-ci
Add CI / CD actions
- Loading branch information
Showing
7 changed files
with
275 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
## Related issues / PRs | ||
- # | ||
|
||
## Summarize Changes | ||
1. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
name: CI | ||
|
||
on: | ||
workflow_dispatch: | ||
|
||
push: | ||
branches: | ||
- master | ||
pull_request: | ||
branches: | ||
- master | ||
|
||
jobs: | ||
build: | ||
|
||
runs-on: ubuntu-latest | ||
strategy: | ||
matrix: | ||
python: ["3.10"] | ||
|
||
steps: | ||
- name: Clean up space for action | ||
run: rm -rf /opt/hostedtoolcache | ||
|
||
- uses: actions/checkout@v4 | ||
- name: Setup Python | ||
uses: actions/setup-python@v5 | ||
with: | ||
python-version: ${{ matrix.python }} | ||
|
||
- name: Install git and ffmpeg | ||
run: sudo apt-get update && sudo apt-get install -y git ffmpeg | ||
|
||
- name: Install dependencies | ||
run: pip install -r requirements.txt pytest | ||
|
||
- name: Run test | ||
run: python -m pytest -rs tests |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import os.path | ||
import requests | ||
import numpy as np | ||
import subprocess | ||
import torch | ||
|
||
from modules.paths import * | ||
|
||
TEST_MODEL = "sam2.1_hiera_tiny" | ||
TEST_VIDEO_URL = "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerEscapes.mp4" | ||
TEST_IMAGE_URL = "https://raw.githubusercontent.com/test-images/png/refs/heads/main/202105/cs-blue-00f.png" | ||
TEST_VIDEO_PATH = os.path.join(WEBUI_DIR, "tests", "test_video.mp4") | ||
TEST_IMAGE_PATH = os.path.join(WEBUI_DIR, "tests", "test_image.png") | ||
TEST_POINTS = np.array([[850., 424.], | ||
[810., 510.]]) | ||
TEST_LABELS = np.array([1., 1.]) | ||
TEST_BOX = np.array([[662., 197., 1056., 711.]]) | ||
TEST_GRADIO_PROMPT_BOX = [[653.0, 163.0, 2.0, 1132.0, 702.0, 3.0]] | ||
TEST_GRADIO_PROMPT_POINTS = [[817.0, 366.0, 1.0, 0.0, 0.0, 4.0], [864.0, 533.0, 1.0, 0.0, 0.0, 4.0]] | ||
|
||
|
||
def download_test_sam_model(model_name: str): | ||
model_path = os.path.join(MODELS_DIR, model_name) + ".pt" | ||
if os.path.exists(model_path): | ||
return | ||
|
||
from modules.model_downloader import download_sam_model_url | ||
download_sam_model_url(model_type=model_name, model_dir=MODELS_DIR) | ||
|
||
|
||
def download_test_files(): | ||
if not os.path.exists(TEST_IMAGE_PATH): | ||
download_file(TEST_IMAGE_URL, TEST_IMAGE_PATH) | ||
if not os.path.exists(TEST_VIDEO_PATH): | ||
download_file(TEST_VIDEO_URL, TEST_VIDEO_PATH) | ||
trim_video(TEST_VIDEO_PATH, seconds=1) | ||
|
||
|
||
def trim_video(video_path, seconds=1): | ||
temp_output_path = video_path + ".temp.mp4" | ||
|
||
command = [ | ||
"ffmpeg", "-i", video_path, "-t", f"{seconds}", "-c", "copy", temp_output_path | ||
] | ||
|
||
try: | ||
subprocess.run(command, check=True) | ||
os.replace(temp_output_path, video_path) | ||
print(f"Trimmed video to {seconds} seconds and saved to {video_path}") | ||
except subprocess.CalledProcessError as e: | ||
print(f"Error trimming video: {e}") | ||
|
||
|
||
def download_file(url, path): | ||
response = requests.get(url, stream=True) | ||
if response.status_code == 200: | ||
os.makedirs(os.path.dirname(path), exist_ok=True) # Ensure the directory exists | ||
with open(path, "wb") as file: | ||
for chunk in response.iter_content(chunk_size=8192): | ||
file.write(chunk) | ||
print(f"Downloaded {url} to {path}") | ||
else: | ||
print(f"Failed to download {url}. Status code: {response.status_code}") | ||
|
||
|
||
def is_cuda_available(): | ||
return torch.cuda.is_available() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from PIL import Image | ||
import pytest | ||
from typing import Dict, Optional | ||
|
||
from test_config import * | ||
import numpy as np | ||
from modules.paths import * | ||
from modules.constants import * | ||
from modules.sam_inference import SamInference | ||
|
||
|
||
@pytest.mark.skipif( | ||
not is_cuda_available(), | ||
reason="Skipping because this test only works in GPU" | ||
) | ||
@pytest.mark.parametrize( | ||
"model_name,image_path,points,labels,box,multimask_output", | ||
[ | ||
(TEST_MODEL, TEST_IMAGE_PATH, TEST_POINTS, TEST_LABELS, TEST_BOX, True) | ||
] | ||
) | ||
def test_image_segmentation( | ||
model_name: str, | ||
image_path: str, | ||
points: np.ndarray, | ||
labels: np.ndarray, | ||
box: np.ndarray, | ||
multimask_output: bool | ||
): | ||
download_test_files() | ||
|
||
inferencer = SamInference() | ||
print("Device:", inferencer.device) | ||
|
||
image = load_image(image_path) | ||
|
||
hparams = { | ||
"multimask_output": multimask_output, | ||
} | ||
|
||
masks, scores, logits = inferencer.predict_image( | ||
image=image, | ||
model_type=model_name, | ||
point_coords=points, | ||
point_labels=labels, | ||
**hparams | ||
) | ||
|
||
assert isinstance(masks, np.ndarray) | ||
assert isinstance(scores, np.ndarray) | ||
assert isinstance(logits, np.ndarray) | ||
|
||
|
||
def load_image(image_path): | ||
image = Image.open(image_path).convert('RGB') | ||
image_array = np.array(image) | ||
return image_array |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import pytest | ||
from typing import Dict | ||
|
||
from test_config import * | ||
import numpy as np | ||
from modules.paths import * | ||
from modules.constants import * | ||
from modules.sam_inference import SamInference | ||
|
||
|
||
@pytest.mark.skipif( | ||
not is_cuda_available(), | ||
reason="Skipping because this test only works in GPU" | ||
) | ||
@pytest.mark.parametrize( | ||
"model_name,video_path,points,labels,box", | ||
[ | ||
(TEST_MODEL, TEST_VIDEO_PATH, TEST_POINTS, TEST_LABELS, TEST_BOX) | ||
] | ||
) | ||
def test_video_segmentation( | ||
model_name: str, | ||
video_path: str, | ||
points: np.ndarray, | ||
labels: np.ndarray, | ||
box: np.ndarray | ||
): | ||
download_test_files() | ||
|
||
inferencer = SamInference() | ||
print("Device:", inferencer.device) | ||
inferencer.init_video_inference_state( | ||
vid_input=video_path, | ||
model_type=model_name, | ||
) | ||
|
||
inferencer.add_prediction_to_frame( | ||
frame_idx=0, | ||
obj_id=0, | ||
points=TEST_POINTS, | ||
labels=TEST_LABELS, | ||
) | ||
|
||
inferencer.add_prediction_to_frame( | ||
frame_idx=1, | ||
obj_id=1, | ||
box=TEST_BOX, | ||
) | ||
|
||
video_segments = inferencer.propagate_in_video() | ||
|
||
assert video_segments and isinstance(video_segments, Dict) | ||
|
||
|
||
@pytest.mark.skipif( | ||
not is_cuda_available(), | ||
reason="Skipping because this test only works in GPU" | ||
) | ||
@pytest.mark.parametrize( | ||
"model_name,video_path,gradio_prompt", | ||
[ | ||
(TEST_MODEL, TEST_VIDEO_PATH, TEST_GRADIO_PROMPT_BOX) | ||
] | ||
) | ||
def test_filtered_video_creation_pipeline( | ||
model_name: str, | ||
video_path: str, | ||
gradio_prompt: np.ndarray, | ||
): | ||
download_test_files() | ||
|
||
inferencer = SamInference() | ||
print("Device:", inferencer.device) | ||
inferencer.init_video_inference_state( | ||
vid_input=video_path, | ||
model_type=model_name, | ||
) | ||
prompt_data = { | ||
"points": gradio_prompt | ||
} | ||
|
||
out_path, out_path = inferencer.create_filtered_video( | ||
image_prompt_input_data=prompt_data, | ||
filter_mode=COLOR_FILTER, | ||
frame_idx=0, | ||
color_hex=DEFAULT_COLOR, | ||
invert_mask=True | ||
) | ||
|
||
assert os.path.exists(out_path) |