Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SAM2: image pipeline #358

Merged
merged 24 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
b27fa20
Add initial changes for SAM2 image sample
parthchadha Nov 8, 2024
9c9c871
Merge remote-tracking branch 'origin/main' into sam2-image
parthchadha Nov 9, 2024
cdc57d1
More improvements with sample usability
parthchadha Nov 12, 2024
f0149c7
Merge remote-tracking branch 'origin/main' into sam2-image
parthchadha Nov 12, 2024
f31b401
Add requirements.txt file
parthchadha Nov 12, 2024
423d2b3
Further clean up of sample, use fp16 everywhere except sam prompt enc…
parthchadha Nov 12, 2024
a03eaff
Batch inference for sam2 image pipeline (#382)
yizhuoz004 Nov 18, 2024
59ae33f
Remove dead code, apply model dtype to all models, reduce torch conve…
yizhuoz004 Nov 21, 2024
08069ad
Remove torch conversions in sam2_base
yizhuoz004 Nov 22, 2024
14b53ca
Include timing in demo script
parthchadha Nov 22, 2024
6324c61
Merge remote-tracking branch 'origin/main' into sam2-image
parthchadha Nov 25, 2024
1ddeb22
Include original license
parthchadha Nov 25, 2024
b71223f
Fix build error due to dtype mismatch
parthchadha Nov 26, 2024
4c7b3e3
Add testing for image demo in CI
parthchadha Nov 26, 2024
e2690ce
Use ascii char for tolerance; remove redundant forward calls in memor…
parthchadha Nov 27, 2024
79e1f01
wget images used for sample
parthchadha Nov 27, 2024
9478a24
Rearrange installation instructions
parthchadha Dec 2, 2024
0cde67d
Merge remote-tracking branch 'origin/main' into sam2-image
parthchadha Dec 2, 2024
f78c57f
Fix L1 failures
parthchadha Dec 2, 2024
ebc1aab
Remove perf logs from testing; remove dead code
parthchadha Dec 3, 2024
5de9e5f
Merge remote-tracking branch 'origin/main' into sam2-image
parthchadha Dec 3, 2024
18aa6a3
Merge remote-tracking branch 'origin/main' into sam2-image
parthchadha Dec 6, 2024
3f8e12a
Update license information
parthchadha Dec 6, 2024
dda9249
Remove use of tp.Parameter
parthchadha Dec 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions tripy/examples/segment-anything-model-v2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ This is an implementation of SAM2 model ([original repository](https://github.co
python3 image_demo.py
```

<!-- Tripy: TEST: EXPECTED_STDOUT Start -->
<!--
```
Generating image embedding took {137.81±10%} ms per run (averaged over 100 runs, with 5 warmup runs)
parthchadha marked this conversation as resolved.
Show resolved Hide resolved
Predicting masks took {37.78±10%} ms per run (averaged over 100 runs, with 5 warmup runs)
Scores for each prediction: {0.78759766±5%} {0.640625±5%} {0.05099487±5%}
```
-->
<!-- Tripy: TEST: EXPECTED_STDOUT End -->

### Video segmentation pipeline

TBD
Expand Down
12 changes: 6 additions & 6 deletions tripy/examples/segment-anything-model-v2/image_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ def process_predictions(
plt.savefig(os.path.join(save_path, f"mask_{i}_score_{score:.3f}.png"), bbox_inches="tight", pad_inches=0)
plt.close(fig)

print(f"Scores for each prediction: {' '.join(map(str, scores))}")

return {
"masks": np.array(processed_masks),
"scores": scores,
Expand Down Expand Up @@ -189,17 +191,16 @@ def time_function(func, num_warmup=5, num_runs=100, description=""):
# Warmup runs
for _ in range(num_warmup):
func()

tp.default_stream().synchronize()
torch.cuda.synchronize()
tp.default_stream().synchronize()
torch.cuda.synchronize()

# Actual timing
start = time.perf_counter()
for _ in range(num_runs):
output = func()
tp.default_stream().synchronize()
torch.cuda.synchronize()

tp.default_stream().synchronize()
torch.cuda.synchronize()
end = time.perf_counter()

avg_time_ms = (end - start) * 1000 / num_runs
Expand Down Expand Up @@ -244,7 +245,6 @@ def predict_masks():
input_labels=input_label,
save_path=save_path,
)

return results


Expand Down
48 changes: 45 additions & 3 deletions tripy/tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,43 @@ def __str__(self):
return os.path.relpath(self.path, EXAMPLES_ROOT)


EXAMPLES = [Example(["nanogpt"])]
EXAMPLES = [Example(["nanogpt"]), Example(["segment-anything-model-v2"])]


@pytest.mark.l1
@pytest.mark.parametrize("example", EXAMPLES, ids=lambda case: str(case))
def test_examples(example, sandboxed_install_run):

def test_with_tolerance(number, value, tolerance):
parthchadha marked this conversation as resolved.
Show resolved Hide resolved
tolerance = float(tolerance) / 100
lower = float(number) * (1 - tolerance)
upper = float(number) * (1 + tolerance)
parthchadha marked this conversation as resolved.
Show resolved Hide resolved
try:
return lower <= float(value) <= upper
except ValueError:
return False
parthchadha marked this conversation as resolved.
Show resolved Hide resolved

def process_tolerances(expected_output):
specs = []
placeholder_regex = r"{(\d+\.?\d*)±(\d+)%}"
pattern = expected_output

# Replace tolerance patterns with more flexible capture group
matches = list(re.finditer(placeholder_regex, pattern))
for match in matches:
specs.append((match.group(1), match.group(2)))
pattern = pattern.replace(match.group(0), r"(\d+\.?\d*)", 1)

# Escape parentheses but not our capture group
pattern = pattern.replace("(", r"\(")
pattern = pattern.replace(")", r"\)")
pattern = pattern.replace(r"\(\d+\.?\d*\)", r"(\d+\.?\d*)")

# Make whitespace flexible
pattern = pattern.replace(" ", r"\s+")
parthchadha marked this conversation as resolved.
Show resolved Hide resolved

return pattern.strip(), specs

with open(example.readme, "r", encoding="utf-8") as f:
contents = f.read()
# Check that the README has all the expected sections.
Expand All @@ -101,9 +132,20 @@ def test_examples(example, sandboxed_install_run):

code = str(block)
if block.has_marker("test: expected_stdout"):
print("Checking command output against expected output: ", end="")
out = statuses[-1].stdout.strip()
matched = re.match(dedent(code).strip(), out)
expected = dedent(code).strip()
pattern, specs = process_tolerances(expected)

match = re.search(pattern, out)
if match and specs:
# Check if captured numbers are within tolerance
matched = all(
test_with_tolerance(expected, actual, tolerance)
for (expected, tolerance), actual in zip(specs, match.groups())
)
else:
matched = bool(match)

print("matched!" if matched else "did not match!")
print(f"==== STDOUT ====\n{out}")
assert matched
Expand Down