Skip to content

Commit

Permalink
feat: add sam 2 pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
suhailkakar committed Oct 9, 2024
1 parent 9e698b2 commit 82788bd
Show file tree
Hide file tree
Showing 7 changed files with 275 additions and 46 deletions.
24 changes: 22 additions & 2 deletions packages/www/components/ModelGallery/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,25 @@ const audioToTextInputs: Input[] = [
},
];

const segmentationInputs: Input[] = [
{
id: "image",
name: "Image",
type: "segment_file",
required: true,
description: "The image to segment",
group: "prompt",
},
{
id: "box",
name: "Box",
type: "",
required: true,
description: "A length 4 array given as a box prompt [x1, y1, x2, y2]",
group: "prompt",
},
];

const availableModels: Model[] = [
{
id: "RealVisXL_V4.0_Lightning",
Expand Down Expand Up @@ -248,7 +267,7 @@ const availableModels: Model[] = [
pipeline: "Segmentation",
image: "sam2-hiera-large.png",
huggingFaceId: "facebook/sam2-hiera-large",
inputs: audioToTextInputs,
inputs: segmentationInputs,
},
];

Expand All @@ -267,8 +286,9 @@ type Input = {
id: string;
name: string;
type: string;
defaultValue?: string | number;
defaultValue?: string | number | boolean;
required: boolean;
disabled?: boolean;
description: string;
group: string;
};
Expand Down
3 changes: 2 additions & 1 deletion packages/www/css/tailwind.css
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
--input: 240 5.9% 90%;
--ring: 240 5% 64.9%;
--radius: 0.5rem;

--loader-background: 240 3.7% 90.9%;
--chart-1: 12 76% 61%;
--chart-2: 173 58% 39%;
--chart-3: 197 37% 24%;
Expand Down Expand Up @@ -52,6 +52,7 @@
--border: 240 3.7% 15.9%;
--input: 240 3.7% 15.9%;
--ring: 240 4.9% 83.9%;
--loader-background: 240 3.7% 15.9%;

--chart-1: 220 70% 50%;
--chart-2: 160 60% 45%;
Expand Down
11 changes: 11 additions & 0 deletions packages/www/hooks/use-api/endpoints/ai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,14 @@ export const audioToText = async (formData: any) => {

return [text];
};

export const segmentImage = async (formData: any) => {
console.log("formData", formData);
const url = `/beta/generate/segment-anything-2`;
const [res, image] = await context.fetch(url, {
method: "POST",
body: formData,
});

return image;
};
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import React, { useRef, useState } from "react";
import React, { useEffect, useRef, useState } from "react";
import type {
Model as ModelT,
Output,
Expand All @@ -8,20 +8,29 @@ import { Button } from "components/ui/button";
import { Textarea } from "components/ui/textarea";
import { Input } from "components/ui/input";
import { useApi } from "hooks";
import { Loader2 } from "lucide-react";
import { Loader2, X } from "lucide-react";

export default function Form({
model,
setOutput,
setGenerationTime,
loading,
setLoading,
}: {
model: ModelT;
setOutput: (output: Output[]) => void;
setGenerationTime: (time: number) => void;
loading: boolean;
setLoading: (loading: boolean) => void;
}) {
const { textToImage, upscale, imageToVideo, imageToImage, audioToText } =
useApi();
const [loading, setLoading] = useState<boolean>(false);
const {
textToImage,
upscale,
imageToVideo,
imageToImage,
audioToText,
segmentImage,
} = useApi();
const startTimeRef = useRef<number | null>(null);
const timerRef = useRef<NodeJS.Timeout | null>(null);

Expand Down Expand Up @@ -50,7 +59,7 @@ export default function Form({
}
}, 100);

console.log(model?.pipeline);
console.log(formData);
switch (model?.pipeline) {
case "Text to Image":
const textToImageRes = await textToImage(formInputs);
Expand All @@ -72,6 +81,10 @@ export default function Form({
const audioToTextRes = await audioToText(formData);
setOutput(audioToTextRes);
break;
case "Segmentation":
const segmentImageRes = await segmentImage(formData);
setOutput(segmentImageRes.images);
break;
case "image-to-image":
break;
}
Expand Down Expand Up @@ -115,7 +128,7 @@ export default function Form({
.map((input) => (
<div key={input.id}>
<Label>{input.name}</Label>
<div className="mt-1">{renderInput(input)}</div>
<div className="mt-1">{renderInput(input, formRef)}</div>
</div>
))}
</fieldset>
Expand Down Expand Up @@ -148,7 +161,11 @@ export default function Form({
);
}

const renderInput = (input: any) => {
const renderInput = (input: any, formRef: React.RefObject<HTMLFormElement>) => {
const [selectedSegmentImage, setSelectedSegmentImage] = useState<
string | null
>(null);

switch (input.type) {
case "textarea":
return (
Expand Down Expand Up @@ -178,13 +195,177 @@ const renderInput = (input: any) => {
type="file"
/>
);
case "segment_file":
return <SegmentInput formRef={formRef} input={input} />;
default:
return (
<Input
name={input.id}
placeholder={input.description}
required={input.required}
disabled={input.disabled}
type="text"
defaultValue={input.defaultValue}
/>
);
}
};

const SegmentInput = ({
formRef,
input,
}: {
formRef: React.RefObject<HTMLFormElement>;
input: any;
}) => {
const [selectedSegmentImage, setSelectedSegmentImage] = useState<
string | null
>(null);
const canvasRef = useRef<HTMLCanvasElement | null>(null);
const [isDrawing, setIsDrawing] = useState(false);
const [box, setBox] = useState<{
startX: number;
startY: number;
endX: number;
endY: number;
} | null>(null);
const [imageDimensions, setImageDimensions] = useState<{
width: number;
height: number;
} | null>(null);

useEffect(() => {
if (selectedSegmentImage && canvasRef.current) {
const image = new Image();
image.src = selectedSegmentImage;

Check warning

Code scanning / CodeQL

DOM text reinterpreted as HTML Medium

DOM text
is reinterpreted as HTML without escaping meta-characters.
image.onload = () => {
const canvas = canvasRef.current!;
const ctx = canvas.getContext("2d");
if (ctx) {
const aspectRatio = image.height / image.width;
canvas.width = 400; // Set a fixed width
canvas.height = canvas.width * aspectRatio;
setImageDimensions({ width: canvas.width, height: canvas.height });

ctx.clearRect(0, 0, canvas.width, canvas.height);
ctx.drawImage(image, 0, 0, canvas.width, canvas.height);
}
};
}
}, [selectedSegmentImage]);

const handleMouseDown = (e: React.MouseEvent<HTMLCanvasElement>) => {
if (!selectedSegmentImage) return;
const rect = e.currentTarget.getBoundingClientRect();
const startX =
(e.clientX - rect.left) * (canvasRef.current!.width / rect.width);
const startY =
(e.clientY - rect.top) * (canvasRef.current!.height / rect.height);
setBox({ startX, startY, endX: startX, endY: startY });
setIsDrawing(true);
};

const handleMouseMove = (e: React.MouseEvent<HTMLCanvasElement>) => {
if (!isDrawing || !box) return;
const rect = e.currentTarget.getBoundingClientRect();
const endX =
(e.clientX - rect.left) * (canvasRef.current!.width / rect.width);
const endY =
(e.clientY - rect.top) * (canvasRef.current!.height / rect.height);
setBox((prevBox) => ({ ...prevBox!, endX, endY }));
drawBox();
};

const handleMouseUp = () => {
if (isDrawing && box) {
const boxValue = JSON.stringify([
Math.round(box.startX),
Math.round(box.startY),
Math.round(box.endX),
Math.round(box.endY),
]);
}
setIsDrawing(false);
};

const drawBox = () => {
if (canvasRef.current && box && selectedSegmentImage) {
const ctx = canvasRef.current.getContext("2d");
if (ctx && imageDimensions) {
const image = new Image();
image.src = selectedSegmentImage;

Check warning

Code scanning / CodeQL

DOM text reinterpreted as HTML Medium

DOM text
is reinterpreted as HTML without escaping meta-characters.
image.style.borderRadius = "15px";
image.onload = () => {
ctx.clearRect(0, 0, imageDimensions.width, imageDimensions.height);
ctx.drawImage(
image,
0,
0,
imageDimensions.width,
imageDimensions.height
);
ctx.strokeStyle = "red";
ctx.lineWidth = 2;
ctx.strokeRect(
box.startX,
box.startY,
box.endX - box.startX,
box.endY - box.startY
);
};
}
}
};

return (
<>
<Input
name={input.id}
placeholder={input.description}
required={input.required}
type="file"
onChange={(e) => {
const file = e.target.files?.[0];
if (file) {
setSelectedSegmentImage(URL.createObjectURL(file));
} else {
setSelectedSegmentImage(null);
}
}}
/>
{selectedSegmentImage && (
<div className="mt-4">
<Label className="font-normal">
Draw a box around the object you want to segment
</Label>
<div className="relative mt-1">
<canvas
ref={canvasRef}
className="w-full object-contain rounded-md border-2 border-input"
style={{ maxWidth: "400px" }}
onMouseDown={handleMouseDown}
onMouseMove={handleMouseMove}
onMouseUp={handleMouseUp}
/>
<button
className="absolute top-2 right-2 bg-white/50 rounded-full p-2 shadow"
onClick={() => {
setSelectedSegmentImage(null);
if (formRef.current) {
const inputElement = formRef.current.elements.namedItem(
input.id
) as HTMLInputElement;
if (inputElement) {
inputElement.value = "";
}
}
}}>
<X className="h-4 w-4 text-muted-foreground" />
</button>
</div>
<input type="hidden" name="box" />
</div>
)}
</>
);
};
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ export default function PlaygroundPage() {
const model = availableModels.find((model) => model.id === id);

const [output, setOutput] = useState<OutputT[]>([]);
const [loading, setLoading] = useState(false);

if (!user) {
return <Layout />;
Expand All @@ -64,13 +65,16 @@ export default function PlaygroundPage() {
<Form
model={model}
setOutput={setOutput}
setLoading={setLoading}
loading={loading}
setGenerationTime={setGenerationTime}
/>
</div>
<div className="md:w-[70%]">
<Output
model={model}
output={output}
loading={loading}
generationTime={generationTime}
/>
</div>
Expand Down
Loading

0 comments on commit 82788bd

Please sign in to comment.