diff --git a/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java index 4299af02ba3..d67dc9e4b6d 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java @@ -12,6 +12,9 @@ */ package ai.djl.modality.cv.translator; +import ai.djl.Model; +import ai.djl.ModelException; +import ai.djl.inference.Predictor; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; import ai.djl.modality.cv.output.BoundingBox; @@ -27,15 +30,21 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; +import ai.djl.translate.ArgumentsUtil; import ai.djl.translate.NoBatchifyTranslator; +import ai.djl.translate.NoopTranslator; import ai.djl.translate.Pipeline; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; import java.io.IOException; +import java.nio.file.Files; import java.nio.file.Path; +import java.nio.file.Paths; import java.util.Collections; import java.util.List; +import java.util.Map; +import java.util.UUID; /** A {@link Translator} that handles mask generation task. */ public class Sam2Translator implements NoBatchifyTranslator { @@ -44,13 +53,38 @@ public class Sam2Translator implements NoBatchifyTranslator predictor; + private String encoderPath; /** Constructs a {@code Sam2Translator} instance. */ - public Sam2Translator() { + public Sam2Translator(Builder builder) { pipeline = new Pipeline(); pipeline.add(new Resize(1024, 1024)); pipeline.add(new ToTensor()); pipeline.add(new Normalize(MEAN, STD)); + this.encoderPath = builder.encoderPath; + } + + /** {@inheritDoc} */ + @Override + public void prepare(TranslatorContext ctx) throws IOException, ModelException { + if (encoderPath == null) { + return; + } + Model model = ctx.getModel(); + Path path = Paths.get(encoderPath); + if (!path.isAbsolute() && Files.notExists(path)) { + path = model.getModelPath().resolve(encoderPath); + } + if (!Files.exists(path)) { + throw new IOException("encoder model not found: " + encoderPath); + } + NDManager manager = ctx.getNDManager(); + Model encoder = manager.getEngine().newModel("encoder", manager.getDevice()); + encoder.load(path); + predictor = encoder.newPredictor(new NoopTranslator(null)); + model.getNDManager().attachInternal(UUID.randomUUID().toString(), predictor); + model.getNDManager().attachInternal(UUID.randomUUID().toString(), encoder); } /** {@inheritDoc} */ @@ -72,7 +106,21 @@ public NDList processInput(TranslatorContext ctx, Sam2Input input) throws Except NDArray locations = manager.create(buf, new Shape(1, numPoints, 2)); NDArray labels = manager.create(input.getLabels()); - return new NDList(array, locations, labels); + if (predictor == null) { + return new NDList(array, locations, labels); + } + + NDList embeddings = predictor.predict(new NDList(array)); + NDArray mask = manager.zeros(new Shape(1, 1, 256, 256)); + NDArray hasMask = manager.zeros(new Shape(1)); + return new NDList( + embeddings.get(2), + embeddings.get(0), + embeddings.get(1), + locations, + labels, + mask, + hasMask); } /** {@inheritDoc} */ @@ -101,6 +149,55 @@ public DetectedObjects processOutput(TranslatorContext ctx, NDList list) throws return new DetectedObjects(classes, probabilities, boxes); } + /** + * Creates a builder to build a {@code Sam2Translator}. + * + * @return a new builder + */ + public static Builder builder() { + return builder(Collections.emptyMap()); + } + + /** + * Creates a builder to build a {@code Sam2Translator} with specified arguments. + * + * @param arguments arguments to specify builder options + * @return a new builder + */ + public static Builder builder(Map arguments) { + return new Builder(arguments); + } + + /** The builder for Sam2Translator. */ + public static class Builder { + + String encoderPath; + + Builder(Map arguments) { + encoderPath = ArgumentsUtil.stringValue(arguments, "encoder"); + } + + /** + * Sets the encoder model path. + * + * @param encoderPath the encoder model path + * @return the builder + */ + public Builder optEncoderPath(String encoderPath) { + this.encoderPath = encoderPath; + return this; + } + + /** + * Builds the translator. + * + * @return the new translator + */ + public Sam2Translator build() { + return new Sam2Translator(this); + } + } + /** A class represents the segment anything input. */ public static final class Sam2Input { @@ -149,8 +246,12 @@ float[] toLocationArray(int width, int height) { return ret; } - int[][] getLabels() { - return new int[][] {labels.stream().mapToInt(Integer::intValue).toArray()}; + float[][] getLabels() { + float[][] buf = new float[1][labels.size()]; + for (int i = 0; i < labels.size(); ++i) { + buf[0][i] = labels.get(i); + } + return buf; } /** diff --git a/api/src/main/java/ai/djl/modality/cv/translator/Sam2TranslatorFactory.java b/api/src/main/java/ai/djl/modality/cv/translator/Sam2TranslatorFactory.java index 82fd8c6f6bb..299b4b19b18 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/Sam2TranslatorFactory.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/Sam2TranslatorFactory.java @@ -13,8 +13,6 @@ package ai.djl.modality.cv.translator; import ai.djl.Model; -import ai.djl.modality.cv.Image; -import ai.djl.modality.cv.output.CategoryMask; import ai.djl.modality.cv.output.DetectedObjects; import ai.djl.modality.cv.translator.Sam2Translator.Sam2Input; import ai.djl.translate.Translator; @@ -43,8 +41,8 @@ public class Sam2TranslatorFactory implements TranslatorFactory, Serializable { @SuppressWarnings("unchecked") public Translator newInstance( Class input, Class output, Model model, Map arguments) { - if (input == Image.class && output == CategoryMask.class) { - return (Translator) new Sam2Translator(); + if (input == Sam2Input.class && output == DetectedObjects.class) { + return (Translator) Sam2Translator.builder(arguments).build(); } throw new IllegalArgumentException("Unsupported input/output types."); } diff --git a/examples/docs/trace_sam2_img.py b/examples/docs/trace_sam2_img.py index bb80795202b..ce5a94ed0ce 100644 --- a/examples/docs/trace_sam2_img.py +++ b/examples/docs/trace_sam2_img.py @@ -10,8 +10,9 @@ # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for # the specific language governing permissions and limitations under the License. +import os import sys -from typing import Tuple +from typing import Any import torch from sam2.modeling.sam2_base import SAM2Base @@ -19,77 +20,129 @@ from torch import nn -class Sam2Wrapper(nn.Module): +class SAM2ImageEncoder(nn.Module): - def __init__( - self, - sam_model: SAM2Base, - ) -> None: + def __init__(self, sam_model: SAM2Base) -> None: super().__init__() self.model = sam_model + self.image_encoder = sam_model.image_encoder + self.no_mem_embed = sam_model.no_mem_embed - # Spatial dim for backbone feature maps - self._bb_feat_sizes = [ - (256, 256), - (128, 128), - (64, 64), - ] + def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]: + backbone_out = self.image_encoder(x) + backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0( + backbone_out["backbone_fpn"][0]) + backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1( + backbone_out["backbone_fpn"][1]) - def extract_features( - self, - input_image: torch.Tensor, - ) -> (torch.Tensor, torch.Tensor, torch.Tensor): - backbone_out = self.model.forward_image(input_image) - _, vision_feats, _, _ = self.model._prepare_backbone_features( - backbone_out) - # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos - if self.model.directly_add_no_mem_embed: - vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + feature_maps = backbone_out["backbone_fpn"][-self.model. + num_feature_levels:] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.model. + num_feature_levels:] + + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + + # flatten NxCxHxW to HWxNxC + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_feats[-1] = vision_feats[-1] + self.no_mem_embed feats = [ - feat.permute(1, 2, - 0).view(1, -1, *feat_size) for feat, feat_size in zip( - vision_feats[::-1], self._bb_feat_sizes[::-1]) + feat.permute(1, 2, 0).reshape(1, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1]) ][::-1] - return feats[-1], feats[0], feats[1] + return feats[0], feats[1], feats[2] - def forward( - self, - input_image: torch.Tensor, - point_coords: torch.Tensor, - point_labels: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - image_embed, feature_1, feature_2 = self.extract_features(input_image) - return self.predict(point_coords, point_labels, image_embed, feature_1, - feature_2) - def predict( +class SAM2ImageDecoder(nn.Module): + + def __init__(self, sam_model: SAM2Base, multimask_output: bool) -> None: + super().__init__() + self.mask_decoder = sam_model.sam_mask_decoder + self.prompt_encoder = sam_model.sam_prompt_encoder + self.model = sam_model + self.img_size = sam_model.image_size + self.multimask_output = multimask_output + self.sparse_embedding = None + + @torch.no_grad() + def forward( self, + image_embed: torch.Tensor, + high_res_feats_0: torch.Tensor, + high_res_feats_1: torch.Tensor, point_coords: torch.Tensor, point_labels: torch.Tensor, - image_embed: torch.Tensor, - feats_1: torch.Tensor, - feats_2: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - concat_points = (point_coords, point_labels) - - sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( - points=concat_points, - boxes=None, - masks=None, - ) - - low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder( - image_embeddings=image_embed[0].unsqueeze(0), - image_pe=self.model.sam_prompt_encoder.get_dense_pe(), - sparse_prompt_embeddings=sparse_embeddings, - dense_prompt_embeddings=dense_embeddings, - multimask_output=True, + mask_input: torch.Tensor, + has_mask_input: torch.Tensor, + ): + sparse_embedding = self._embed_points(point_coords, point_labels) + self.sparse_embedding = sparse_embedding + dense_embedding = self._embed_masks(mask_input, has_mask_input) + + high_res_feats = [high_res_feats_0, high_res_feats_1] + image_embed = image_embed + + masks, iou_predictions, _, _ = self.mask_decoder.predict_masks( + image_embeddings=image_embed, + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embedding, + dense_prompt_embeddings=dense_embedding, repeat_image=False, - high_res_features=[feats_1, feats_2], + high_res_features=high_res_feats, ) - return low_res_masks, iou_predictions + + if self.multimask_output: + masks = masks[:, 1:, :, :] + iou_predictions = iou_predictions[:, 1:] + else: + masks, iou_pred = ( + self.mask_decoder._dynamic_multimask_via_stability( + masks, iou_predictions)) + + masks = torch.clamp(masks, -32.0, 32.0) + + return masks, iou_predictions + + def _embed_points(self, point_coords: torch.Tensor, + point_labels: torch.Tensor) -> torch.Tensor: + + point_coords = point_coords + 0.5 + + padding_point = torch.zeros((point_coords.shape[0], 1, 2), + device=point_coords.device) + padding_label = -torch.ones( + (point_labels.shape[0], 1), device=point_labels.device) + point_coords = torch.cat([point_coords, padding_point], dim=1) + point_labels = torch.cat([point_labels, padding_label], dim=1) + + point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size + point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size + + point_embedding = self.prompt_encoder.pe_layer._pe_encoding( + point_coords) + point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) + + point_embedding = point_embedding * (point_labels != -1) + point_embedding = (point_embedding + + self.prompt_encoder.not_a_point_embed.weight * + (point_labels == -1)) + + for i in range(self.prompt_encoder.num_point_embeddings): + point_embedding = (point_embedding + + self.prompt_encoder.point_embeddings[i].weight * + (point_labels == i)) + + return point_embedding + + def _embed_masks(self, input_mask: torch.Tensor, + has_mask_input: torch.Tensor) -> torch.Tensor: + mask_embedding = has_mask_input * self.prompt_encoder.mask_downscaling( + input_mask) + mask_embedding = mask_embedding + ( + 1 - has_mask_input + ) * self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) + return mask_embedding def trace_model(model_id: str): @@ -98,19 +151,47 @@ def trace_model(model_id: str): else: device = torch.device("cpu") + model_name = f"{model_id[9:]}" + os.makedirs(model_name) + predictor = SAM2ImagePredictor.from_pretrained(model_id, device=device) - model = Sam2Wrapper(predictor.model) + encoder = SAM2ImageEncoder(predictor.model) + decoder = SAM2ImageDecoder(predictor.model, True) input_image = torch.ones(1, 3, 1024, 1024).to(device) - input_point = torch.ones(1, 1, 2).to(device) - input_labels = torch.ones(1, 1, dtype=torch.int32, device=device) - - converted = torch.jit.trace_module( - model, { - "extract_features": input_image, - "forward": (input_image, input_point, input_labels) - }) - torch.jit.save(converted, f"{model_id[9:]}.pt") + high_res_feats_0, high_res_feats_1, image_embed = encoder(input_image) + + converted = torch.jit.trace(encoder, input_image) + torch.jit.save(converted, f"model_name/encoder.pt") + + # trace decoder model + embed_size = ( + predictor.model.image_size // predictor.model.backbone_stride, + predictor.model.image_size // predictor.model.backbone_stride, + ) + mask_input_size = [4 * x for x in embed_size] + + point_coords = torch.randint(low=0, + high=1024, + size=(1, 5, 2), + dtype=torch.float) + point_labels = torch.randint(low=0, high=1, size=(1, 5), dtype=torch.float) + mask_input = torch.randn(1, 1, *mask_input_size, dtype=torch.float) + has_mask_input = torch.tensor([1], dtype=torch.float) + + converted = torch.jit.trace( + decoder, (image_embed, high_res_feats_0, high_res_feats_1, + point_coords, point_labels, mask_input, has_mask_input)) + torch.jit.save(converted, f"model_name/model_name.pt") + + # save serving.properties + serving_file = os.path.join(model_name, "serving.properties") + with open(serving_file, "w") as f: + f.write( + f"engine=PyTorch\n" + f"option.modelName={model_name}\n" + f"translatorFactory=ai.djl.modality.cv.translator.Sam2TranslatorFactory\n" + f"encoder=encoder.pt") if __name__ == '__main__': diff --git a/examples/src/main/java/ai/djl/examples/inference/cv/SegmentAnything2.java b/examples/src/main/java/ai/djl/examples/inference/cv/SegmentAnything2.java index 8fff3a94127..85afcb8789f 100644 --- a/examples/src/main/java/ai/djl/examples/inference/cv/SegmentAnything2.java +++ b/examples/src/main/java/ai/djl/examples/inference/cv/SegmentAnything2.java @@ -17,8 +17,8 @@ import ai.djl.inference.Predictor; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.output.DetectedObjects; -import ai.djl.modality.cv.translator.Sam2Translator; import ai.djl.modality.cv.translator.Sam2Translator.Sam2Input; +import ai.djl.modality.cv.translator.Sam2TranslatorFactory; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ZooModel; import ai.djl.training.util.ProgressBar; @@ -54,7 +54,7 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran .optModelUrls("djl://ai.djl.pytorch/sam2-hiera-tiny") .optEngine("PyTorch") .optDevice(Device.cpu()) // use sam2-hiera-tiny-gpu for GPU - .optTranslator(new Sam2Translator()) + .optTranslatorFactory(new Sam2TranslatorFactory()) .optProgress(new ProgressBar()) .build(); diff --git a/examples/src/test/java/ai/djl/examples/inference/cv/SegmentAnything2Test.java b/examples/src/test/java/ai/djl/examples/inference/cv/SegmentAnything2Test.java index 705832c4b23..83e8e93131f 100644 --- a/examples/src/test/java/ai/djl/examples/inference/cv/SegmentAnything2Test.java +++ b/examples/src/test/java/ai/djl/examples/inference/cv/SegmentAnything2Test.java @@ -25,7 +25,7 @@ public class SegmentAnything2Test { @Test - public void testInstanceSegmentation() throws ModelException, TranslateException, IOException { + public void testSam2() throws ModelException, TranslateException, IOException { DetectedObjects result = SegmentAnything2.predict(); Classifications.Classification best = result.best(); Assert.assertTrue(Double.compare(best.getProbability(), 0.3) > 0);