
❓ [Question] Is SAM2 supported when compiling with the Dynamo backend on JetPack 6.1 or 6.2? #3478


Open
AyanamiReiFan opened this issue Apr 17, 2025 · 14 comments
Labels: question (Further information is requested)

@AyanamiReiFan

AyanamiReiFan commented Apr 17, 2025

❓ Question

Will SAM2 be compatible with the Dynamo backend on JetPack 6.1/6.2?

Are there any workarounds for the TensorRT version mismatch?

What you have already tried

Here are my attempts and the issues encountered. My device is a Jetson AGX Orin, and I only compile the ImageEncoder of SAM2 (Hiera & FPN, with the position encoding removed). The SAM2 code is from https://github.com/chohk88/sam2/tree/torch-trt:

JetPack 6.1 + PyTorch 2.5 (from https://developer.download.nvidia.cn) + Torch-TensorRT 2.5

Tried compiling SAM2 but encountered errors.

Observed that the PyTorch 2.5 documentation does not mention SAM2 support, likely indicating SAM2 is not yet adapted for this version.

JetPack 6.1 + PyTorch 2.6 (from https://pypi.jetson-ai-lab.dev/jp6/cu126) + Torch-TensorRT 2.6

Installed PyTorch 2.6 from jp6/cu126 and Torch-TensorRT 2.6.

Importing torch_tensorrt failed with ModuleNotFoundError: No module named 'tensorrt.plugin'.

Root cause: Torch-TensorRT 2.6 requires TensorRT 10.7, but JetPack 6.1 provides only TensorRT 10.3.

Found no straightforward way to upgrade TensorRT within JetPack 6.1 due to dependency conflicts.

Cross-Platform Attempt: Compile on x86 + Run on JetPack 6.1

Compiled SAM2 on x86 with Torch-TensorRT 2.6 and exported the model.

Tried running it on JetPack 6.1 with Torch-TensorRT 2.5.

Failed, unsurprisingly, due to the serialization format incompatibility between Torch-TensorRT 2.6 and 2.5.
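
(For reference, the mismatch can be confirmed on-device by printing the installed versions; a minimal sketch:)

import tensorrt
print(tensorrt.__version__)   # 10.3.x on JetPack 6.1
import torch
print(torch.__version__)
# tensorrt.plugin is not present in TensorRT 10.3, so on JetPack 6.1 this
# import fails with ModuleNotFoundError when Torch-TensorRT 2.6 is installed:
import torch_tensorrt
print(torch_tensorrt.__version__)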

AyanamiReiFan added the question label on Apr 17, 2025
@narendasan
Collaborator

cc @peri044 @chohk88

@peri044
Collaborator

peri044 commented Apr 22, 2025

@AyanamiReiFan I don't know of any workarounds for upgrading TRT 10.3 on JetPack. That being said, you could give the 25.03-py3-igpu container a try. This container has TRT 10.9 and the corresponding Torch-TRT version. This might work, although I haven't tested it yet. In the future, JetPack 7 will have TRT 10.6+, which could also fix this issue.
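
(For reference: the tag presumably refers to the NGC PyTorch iGPU image, e.g. nvcr.io/nvidia/pytorch:25.03-py3-igpu, started on Jetson with docker run --runtime nvidia; the exact registry path is an assumption.)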

@narendasan
Collaborator

The iGPU container should also have a much more recent version of Torch-TRT

@AyanamiReiFan
Author

> @AyanamiReiFan I don't know of any workarounds for upgrading TRT 10.3 on JetPack. That being said, you could give the 25.03-py3-igpu container a try. This container has TRT 10.9 and the corresponding Torch-TRT version. This might work, although I haven't tested it yet. In the future, JetPack 7 will have TRT 10.6+, which could also fix this issue.

Thanks very much! I will try it later.

@lanluo-nvidia
Collaborator

@AyanamiReiFan
I have fixed the No module named 'tensorrt.plugin' error.
Here is the PR merged to main: #3518
You can follow the guide on this branch: https://github.com/pytorch/TensorRT/blob/b0baba1d9687ad8a8f1db577abd029e38e3555af/docsrc/getting_started/jetpack.rst

Meanwhile, I will also give SAM2 a try on Jetson.

@lanluo-nvidia
Collaborator

lanluo-nvidia commented Jun 1, 2025

@AyanamiReiFan
I have successfully built and installed SAM2 on a Jetson Orin.

Since https://github.com/pytorch/TensorRT/pull/3524/files has not been merged to main yet, please follow the JetPack guide here for building from the latest source:
https://github.com/pytorch/TensorRT/blob/61b3480b1576b7b445ce883003e2d2aff5795610/docsrc/getting_started/jetpack.rst

@AyanamiReiFan
Author

> @AyanamiReiFan I have successfully built and installed SAM2 on a Jetson Orin.
>
> Since https://github.com/pytorch/TensorRT/pull/3524/files has not been merged to main yet, please follow the JetPack guide here for building from the latest source: https://github.com/pytorch/TensorRT/blob/61b3480b1576b7b445ce883003e2d2aff5795610/docsrc/getting_started/jetpack.rst

Thank you very much. I've read this JetPack guide. Can I assume that this guide essentially installs PyTorch 2.7 and the not-yet-officially-released Torch-TensorRT 2.8 on JetPack 6.2-based Jetson devices? If so, once Torch-TensorRT 2.8 is officially released, will it be properly compatible with JetPack 6.2 Jetson devices?

Additionally, I’m currently trying to follow these installation steps and will provide feedback as soon as possible.

@AyanamiReiFan
Author

AyanamiReiFan commented Jun 2, 2025

> @AyanamiReiFan I have successfully built and installed SAM2 on a Jetson Orin.
>
> Since https://github.com/pytorch/TensorRT/pull/3524/files has not been merged to main yet, please follow the JetPack guide here for building from the latest source: https://github.com/pytorch/TensorRT/blob/61b3480b1576b7b445ce883003e2d2aff5795610/docsrc/getting_started/jetpack.rst

I built torch-tensorrt from the main branch.
When executing the command:
python setup.py bdist_wheel --jetpack
the build got stuck printing "Loading: 0 packages loaded" and the wheel build failed.

Detailed log:

xxxxxxx@ubuntu:~/Develops/env_prepare/TensorRT$ python setup.py bdist_wheel --jetpack
2025/06/02 21:55:43 Downloading https://releases.bazel.build/8.1.1/release/bazel-8.1.1-linux-arm64...
Extracting Bazel installation...
Starting local Bazel server (8.1.1) and connecting to it...
no actions running   (repeated)
DEBUG: Rule 'rules_pkg+' indicated that a canonical reproducible form can be obtained by modifying arguments commit = "17c57f46e5c7cd58f893d7960b4fe6fe59bb77b1"
DEBUG: Repository rules_pkg+ instantiated at:
  <builtin>: in <toplevel>
Repository rule git_repository defined at:
  /home/xxxxxxx/.cache/bazel/_bazel_xxxxxxx/72b5e5ef13d3b4a46b0f6eef458b108c/external/bazel_tools/tools/build_defs/repo/git.bzl:193:33: in <toplevel>
no actions running   (repeated)
Loading: 0 packages loaded   (repeated)
DEBUG: Rule '+_repo_rules3+tensorrt_l4t' indicated that a canonical reproducible form can be obtained by modifying arguments integrity = "sha256-Nefnvwprjho0KdnSvNv2zjDURoPV14q6ufzFK/vpY4o="
DEBUG: Repository +_repo_rules3+tensorrt_l4t instantiated at:
  <builtin>: in <toplevel>
Repository rule http_archive defined at:
  /home/xxxxxxx/.cache/bazel/_bazel_xxxxxxx/72b5e5ef13d3b4a46b0f6eef458b108c/external/bazel_tools/tools/build_defs/repo/http.bzl:392:31: in <toplevel>
Loading: 0 packages loaded   (repeated)

@AyanamiReiFan
Author

AyanamiReiFan commented Jun 2, 2025

> @AyanamiReiFan I have successfully built and installed SAM2 on a Jetson Orin.
>
> Since https://github.com/pytorch/TensorRT/pull/3524/files has not been merged to main yet, please follow the JetPack guide here for building from the latest source: https://github.com/pytorch/TensorRT/blob/61b3480b1576b7b445ce883003e2d2aff5795610/docsrc/getting_started/jetpack.rst

I have identified the reason for the "Loading: 0 packages loaded" issue mentioned above. When configuring the system proxy on Ubuntu, I forgot to set no_proxy, so traffic that should have gone to the local Bazel server was routed to the proxy server instead.
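
(For anyone hitting the same thing: the usual fix is to exempt local traffic from the proxy before building, e.g. export no_proxy=localhost,127.0.0.1; the exact host list depends on your setup.)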

@AyanamiReiFan
Author

> @AyanamiReiFan I have successfully built and installed SAM2 on a Jetson Orin.
>
> Since https://github.com/pytorch/TensorRT/pull/3524/files has not been merged to main yet, please follow the JetPack guide here for building from the latest source: https://github.com/pytorch/TensorRT/blob/61b3480b1576b7b445ce883003e2d2aff5795610/docsrc/getting_started/jetpack.rst

I got this error when installing the wheel. Is torch 2.7 not supported by the main branch?

xxxxxxx@ubuntu:~/Develops/env_prepare/TensorRT/dist$ python -m pip install torch_tensorrt-2.8.0.dev0+727cbd2e9-cp310-cp310-linux_aarch64.whl
Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
Processing ./torch_tensorrt-2.8.0.dev0+727cbd2e9-cp310-cp310-linux_aarch64.whl
INFO: pip is looking at multiple versions of torch-tensorrt to determine which version is compatible with other requirements. This could take a while.
ERROR: Could not find a version that satisfies the requirement torch<2.9.0,>=2.8.0.dev (from torch-tensorrt) (from versions: 1.10.2, 1.11.0, 1.12.0, 1.12.1, 1.13.0, 1.13.1, 2.0.0, 2.0.1, 2.1.0, 2.1.1, 2.1.2, 2.2.0, 2.2.1, 2.2.2, 2.3.0, 2.3.1, 2.4.0, 2.4.1, 2.5.0, 2.5.1, 2.6.0, 2.7.0)
ERROR: No matching distribution found for torch<2.9.0,>=2.8.0.dev

lanluo-nvidia self-assigned this Jun 2, 2025
@lanluo-nvidia
Collaborator

@AyanamiReiFan
Did you use the latest code from this branch:
https://github.com/pytorch/TensorRT/tree/lluo/jetson_build

In the pyproject.toml, if your environment is Tegra, it should pull
"torch>=2.7.0,<2.8.0; 'tegra' in platform_release"
rather than
"torch>=2.8.0.dev,<2.9.0; 'tegra' not in platform_release",

@AyanamiReiFan
Author

AyanamiReiFan commented Jun 2, 2025

> @AyanamiReiFan Did you use the latest code from this branch: https://github.com/pytorch/TensorRT/tree/lluo/jetson_build
>
> In the pyproject.toml, if your environment is Tegra, it should pull "torch>=2.7.0,<2.8.0; 'tegra' in platform_release"
>
> TensorRT/pyproject.toml, line 13 in 61b3480: "torch>=2.8.0.dev,<2.9.0; 'tegra' not in platform_release",

I used the main branch, not this branch.
I added --no-deps to the pip install and it seems to work.
The script examples/dynamo/torch_compile_resnet_example.py seems to run correctly.
I will try running SAM2 on both the main branch and your branch later.

@AyanamiReiFan
Author

AyanamiReiFan commented Jun 2, 2025

> @AyanamiReiFan Did you use the latest code from this branch: https://github.com/pytorch/TensorRT/tree/lluo/jetson_build
>
> In the pyproject.toml, if your environment is Tegra, it should pull "torch>=2.7.0,<2.8.0; 'tegra' in platform_release"
>
> TensorRT/pyproject.toml, line 13 in 61b3480: "torch>=2.8.0.dev,<2.9.0; 'tegra' not in platform_release",

When using your branch, I could compile and run SAM2-Large, but the output image quality is significantly worse than in the examples. Additionally, when attempting to use the base+ or tiny models, compilation fails with these errors:

utils.cpp:2468: CHECK(output_shape.size() == rep_vector.size()) failed. 
ERROR:torch_tensorrt [TensorRT Conversion Context]:Error Code: 9: Skipping tactic 0x0000000000000000 due to exception No Myelin Error exists
utils.cpp:2468: CHECK(output_shape.size() == rep_vector.size()) failed. 
ERROR:torch_tensorrt [TensorRT Conversion Context]:Error Code: 9: Skipping tactic 0x0000000000000000 due to exception No Myelin Error exists
utils.cpp:2468: CHECK(output_shape.size() == rep_vector.size()) failed. 
ERROR:torch_tensorrt [TensorRT Conversion Context]:Error Code: 9: Skipping tactic 0x0000000000000000 due to exception No Myelin Error exists
ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 10: Internal Error (Could not find any implementation for node {ForeignNode[[SLICE]-[aten_ops.slice.Tensor]-[/slice_18]...[SHUFFLE]-[aten_ops.squeeze.dim]-[/squeeze_63] + [SHUFFLE]-[aten_ops._reshape_copy.default]-[/_reshape_copy_424] + [SHUFFLE]-[aten_ops.permute.default]-[/permute_246]]}.)
Traceback (most recent call last):
  File "/home/xxxxxxx/Develops/sam2-torch-trt/export_sam2.py", line 257, in <module>
    trt_model = torch_tensorrt.dynamo.compile(
  File "/home/xxxxxxx/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/_compiler.py", line 712, in compile
    trt_gm = compile_module(
  File "/home/xxxxxxx/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/_compiler.py", line 918, in compile_module
    trt_module = convert_module(
  File "/home/xxxxxxx/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 90, in convert_module
    interpreter_result = interpret_module_to_result(
  File "/home/xxxxxxx/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 69, in interpret_module_to_result
    interpreter_result = interpreter.run()
  File "/home/xxxxxxx/.local/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 739, in run
    assert serialized_engine
AssertionError
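
(One possible workaround, untested here: torch_tensorrt.dynamo.compile accepts a torch_executed_ops argument that keeps the listed ops in PyTorch instead of lowering them to TensorRT, which can sidestep a failing fused node. A minimal sketch; the op below is an illustrative guess based on the error message:)

trt_model = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=(dummy_input,),
    min_block_size=1,
    enabled_precisions={torch.float16},
    # run the slice op named in the failing ForeignNode in eager PyTorch (illustrative)
    torch_executed_ops={"torch.ops.aten.slice.Tensor"},
)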

The experimental code I used is from there, and I removed line 38:
from sam_components import SAM2FullModel

Resulting images:
[three result images attached]

@AyanamiReiFan
Author

AyanamiReiFan commented Jun 3, 2025

> @AyanamiReiFan Did you use the latest code from this branch: https://github.com/pytorch/TensorRT/tree/lluo/jetson_build
>
> In the pyproject.toml, if your environment is Tegra, it should pull "torch>=2.7.0,<2.8.0; 'tegra' in platform_release"
>
> TensorRT/pyproject.toml, line 13 in 61b3480: "torch>=2.8.0.dev,<2.9.0; 'tegra' not in platform_release",

After making some modifications, I was able to successfully compile the Hiera-Tiny model (since I only need the Hiera part of SAM2). However, I'm not entirely sure which specific change(s) actually resolved the issue. Below are the modification details; I hope they are helpful for your development work.

  • I modified hieradet.py so that the pos_embed is computed and cached during model initialization:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
from functools import partial
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from iopath.common.file_io import g_pathmgr

from sam2.modeling.backbones.utils import (
    PatchEmbed,
    window_partition,
    window_unpartition,
)

from sam2.modeling.sam2_utils import DropPath, MLP


def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
    if pool is None:
        return x
    # (B, H, W, C) -> (B, C, H, W)
    x = x.permute(0, 3, 1, 2)
    x = pool(x)
    # (B, C, H', W') -> (B, H', W', C)
    x = x.permute(0, 2, 3, 1)
    if norm:
        x = norm(x)

    return x


class MultiScaleAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        q_pool: nn.Module = None,
    ):
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out
        self.num_heads = num_heads
        self.q_pool = q_pool
        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # qkv with shape (B, H * W, 3, nHead, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
        # q, k, v with shape (B, H * W, nheads, C)
        q, k, v = torch.unbind(qkv, 2)

        # Q pooling (for downsample at stage changes)
        if self.q_pool:
            q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
            H, W = q.shape[1:3]  # downsampled shape
            q = q.reshape(B, H * W, self.num_heads, -1)

        # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
        x = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
        )
        # Transpose back
        x = x.transpose(1, 2)
        x = x.reshape(B, H, W, -1)

        x = self.proj(x)

        return x


class MultiScaleBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        norm_layer: Union[nn.Module, str] = "LayerNorm",
        q_stride: Tuple[int, int] = None,
        act_layer: nn.Module = nn.GELU,
        window_size: int = 0,
    ):
        super().__init__()

        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)

        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)

        self.window_size = window_size

        self.pool, self.q_stride = None, q_stride
        if self.q_stride:
            self.pool = nn.MaxPool2d(
                kernel_size=q_stride, stride=q_stride, ceil_mode=False
            )

        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            num_heads=num_heads,
            q_pool=self.pool,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = MLP(
            dim_out,
            int(dim_out * mlp_ratio),
            dim_out,
            num_layers=2,
            activation=act_layer,
        )

        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x  # B, H, W, C
        x = self.norm1(x)

        # Skip connection
        if self.dim != self.dim_out:
            shortcut = do_pool(self.proj(x), self.pool)

        # Window partition
        window_size = self.window_size
        if window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, window_size)

        # Window Attention + Q Pooling (if stage change)
        x = self.attn(x)
        if self.q_stride:
            # Shapes have changed due to Q pooling
            window_size = self.window_size // self.q_stride[0]
            H, W = shortcut.shape[1:3]

            pad_h = (window_size - H % window_size) % window_size
            pad_w = (window_size - W % window_size) % window_size
            pad_hw = (H + pad_h, W + pad_w)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        # MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class Hiera(nn.Module):
    """
    Reference: https://arxiv.org/abs/2306.00989
    """

    def __init__(
        self,
        embed_dim: int = 96,  # initial embed dim
        num_heads: int = 1,  # initial number of heads
        drop_path_rate: float = 0.0,  # stochastic depth
        q_pool: int = 3,  # number of q_pool stages
        q_stride: Tuple[int, int] = (2, 2),  # downsample stride bet. stages
        stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
        dim_mul: float = 2.0,  # dim_mul factor at stage shift
        head_mul: float = 2.0,  # head_mul factor at stage shift
        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
        # window size per stage, when not using global att.
        window_spec: Tuple[int, ...] = (
            8,
            4,
            14,
            7,
        ),
        # global attn in these blocks
        global_att_blocks: Tuple[int, ...] = (
            12,
            16,
            20,
        ),
        weights_path=None,
        return_interm_layers=True,  # return feats from every stage
        input_size=None,
    ):
        super().__init__()

        assert len(stages) == len(window_spec)
        self.window_spec = window_spec

        depth = sum(stages)
        self.q_stride = q_stride
        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
        assert 0 <= q_pool <= len(self.stage_ends[:-1])
        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
        self.return_interm_layers = return_interm_layers

        self.patch_embed = PatchEmbed(
            embed_dim=embed_dim,
        )
        # Which blocks have global att?
        self.global_att_blocks = global_att_blocks

        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
        self.pos_embed = nn.Parameter(
            torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
        )
        self.pos_embed_window = nn.Parameter(
            torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
        )

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        cur_stage = 1
        self.blocks = nn.ModuleList()

        for i in range(depth):
            dim_out = embed_dim
            # lags by a block, so first block of
            # next stage uses an initial window size
            # of previous stage and final window size of current stage
            window_size = self.window_spec[cur_stage - 1]

            if self.global_att_blocks is not None:
                window_size = 0 if i in self.global_att_blocks else window_size

            if i - 1 in self.stage_ends:
                dim_out = int(embed_dim * dim_mul)
                num_heads = int(num_heads * head_mul)
                cur_stage += 1

            block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                drop_path=dpr[i],
                q_stride=self.q_stride if i in self.q_pool_blocks else None,
                window_size=window_size,
            )

            embed_dim = dim_out
            self.blocks.append(block)

        self.channel_list = (
            [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
            if return_interm_layers
            else [self.blocks[-1].dim_out]
        )

        if weights_path is not None:
            with g_pathmgr.open(weights_path, "rb") as f:
                chkpt = torch.load(f, map_location="cpu")
            logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
        self.input_size = input_size
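        # Cache the interpolated positional embedding for the fixed input size, so the
        # exported graph reads a static buffer instead of recomputing it every forward.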
        if self.input_size is not None:
            self.register_buffer(
                "resized_pos_embed",
                self._get_pos_embed((self.input_size[0] // 4, self.input_size[1] // 4)).clone().detach(),
                persistent=False
            )

    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
        h, w = hw
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        pos_embed = pos_embed + window_embed.tile(
            [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return pos_embed

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        x = self.patch_embed(x)
        # x: (B, H, W, C)

        # Add pos embed
        if self.input_size is not None:
            x = x + self.resized_pos_embed
        else:
            x = x + self._get_pos_embed(x.shape[1:3])

        outputs = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if (i == self.stage_ends[-1]) or (
                i in self.stage_ends and self.return_interm_layers
            ):
                feats = x.permute(0, 3, 1, 2)
                outputs.append(feats)

        return outputs

    def get_layer_id(self, layer_name):
        # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
        num_layers = self.get_num_layers()

        if layer_name.find("rel_pos") != -1:
            return num_layers + 1
        elif layer_name.find("pos_embed") != -1:
            return 0
        elif layer_name.find("patch_embed") != -1:
            return 0
        elif layer_name.find("blocks") != -1:
            return int(layer_name.split("blocks")[1].split(".")[1]) + 1
        else:
            return num_layers + 1

    def get_num_layers(self) -> int:
        return len(self.blocks)
  • I changed the compile code; this is my compile script:
import os.path
import time
import random

import torch
import torch_tensorrt
import tqdm
from torchvision import transforms
from PIL import Image

from sam2.build_sam import build_sam2


def loop(model, input_shape=(1024, 1024), precision=torch.float32):
    with torch.no_grad():
        for _ in range(25):
            src = torch.randn((1, 3, *input_shape), device='cuda', dtype=precision)
            model(src)
        torch.cuda.synchronize()
        start_time = time.time()
        for _ in tqdm.tqdm(range(5000)):
            src = torch.randn((1, 3, *input_shape), device='cuda', dtype=precision)
            model(src)
            torch.cuda.synchronize()
        print(5000 / (time.time() - start_time))


def pil_to_tensor(image, target_size=(224, 224)):

    # Define the preprocessing pipeline
    preprocess = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])
    ])

    # Apply preprocessing and add a batch dimension
    input_tensor = preprocess(image).unsqueeze(0)
    return input_tensor


def check_precision(new_model, origin_model, input_shape=(1024, 1024), precision=torch.float32,
                    test_image_root=None):
    all_image_list = []
    if test_image_root is not None:
        all_image_list = [os.path.join(test_image_root, each) for each in os.listdir(test_image_root)]
    for i in range(10):
        if not all_image_list:  # the list is empty (never None) when no test dir is given
            dummy_input = torch.randn((1, 3, *input_shape), device='cuda', dtype=precision)
        else:
            image_name = random.choice(all_image_list)
            real_image = Image.open(image_name)
            print(os.path.basename(image_name))
            dummy_input = pil_to_tensor(real_image, input_shape).to(device='cuda', dtype=precision)
        print(dummy_input.dtype)
        new_model_result = new_model(dummy_input)
        origin_model_result = origin_model(dummy_input.float())
        for new_feature_map, old_feature_map in zip(new_model_result, origin_model_result):
            # old_feature_map = old_feature_map.to(precision)
            print(old_feature_map.shape)
            abs_diff = (new_feature_map - old_feature_map).abs()
            mean_diff = abs_diff.mean().item()
            max_diff = abs_diff.max().item()
            # rel_diff = abs_diff / (old_feature_map.abs() + 1e-7)  # avoid division by zero
            # mean_rel_diff = rel_diff.mean().item()
            # max_rel_diff = rel_diff.max().item()
            cos_sim = torch.cosine_similarity(new_feature_map, old_feature_map, dim=1)
            mean_cos = cos_sim.mean().item()
            min_cos = cos_sim.min().item()
            print(mean_diff, max_diff, mean_cos, min_cos)


@torch.no_grad()
def main():
    sam2_config = 'configs/sam2.1_only_encoder/sam2.1_hiera_t.yaml'
    sam2_ckpt = './pretrained_checkpoint/sam2.1_hiera_tiny.pt'
    raw_sam2_hiera = build_sam2(sam2_config, ckpt_path=sam2_ckpt, apply_postprocessing=False, removed_prefix='image_encoder.trunk.')
    raw_sam2_hiera.eval()

    sam2_hiera = build_sam2(sam2_config, ckpt_path=sam2_ckpt, apply_postprocessing=False, removed_prefix='image_encoder.trunk.')
    sam2_hiera.eval()

    inference_shape = (672, 896)
    real_check_data_dir = './test_image'
    # compile_type = None
    compile_type = 'dynamo'
    history_compiled_model = "tiny.ep"

    # compile_type = 'tensorrt'
    # history_compiled_model = "tiny.ts"


    if compile_type == 'dynamo':
        if os.path.exists(history_compiled_model):
            sam2_hiera = torch_tensorrt.load(history_compiled_model).module()
            print('load history compiled ep model')
        else:
            if real_check_data_dir is not None:
                all_image_list = [os.path.join(real_check_data_dir, each) for each in os.listdir(real_check_data_dir)]
                dummy_input = pil_to_tensor(
                    Image.open(random.choice(all_image_list)), inference_shape
                ).to(device='cuda', dtype=torch.float16)
            else:
                dummy_input = torch.randn((1, 3, *inference_shape), device='cuda', dtype=torch.float16)
            # sam2_hiera.half()
            print('start compile')
            exp_program = torch.export.export(
                sam2_hiera.half(),
                (dummy_input,),
                strict=True
            )
            print('finish compile stage 1')
            dummy_input = torch_tensorrt.Input(
                shape_mode=0,
                shape=(1, 3, *inference_shape),
                dtype=torch_tensorrt.dtype.float16,
            )
            sam2_hiera = torch_tensorrt.dynamo.compile(
                exp_program,
                inputs=(dummy_input,),
                min_block_size=1,
                enabled_precisions={torch.float16},
                use_fp32_acc=True,
                optimization_level=5,
                # device=torch_tensorrt.Device("dla:0", allow_gpu_fallback=True),
            )
            print('finish compile stage 2')
            torch_tensorrt.save(sam2_hiera, history_compiled_model, inputs=(dummy_input,))
            # sam2_hiera = torch.export.load(history_compiled_model).module()
            sam2_hiera = torch_tensorrt.load(history_compiled_model).module()
            print('finish compile')
            # trt_out = sam2_hiera(dummy_input)
    elif compile_type == 'tensorrt':
        if os.path.exists(history_compiled_model):
            sam2_hiera = torch_tensorrt.load(history_compiled_model)
            print('load history compiled ts model')
        else:
            print('start compile')
            dummy_input = torch.randn((1, 3, *inference_shape), device='cuda', dtype=torch.float16)
            scripted = torch.jit.trace(
                sam2_hiera.half(),
                dummy_input
            )
            scripted = torch.jit.freeze(scripted)
            print('finish compile stage 1')
            # Static compilation via the TorchScript frontend
            sam2_hiera = torch_tensorrt.compile(
                scripted,
                # sam2_hiera.half(),
                ir='torchscript',
                inputs=[torch_tensorrt.Input(shape=dummy_input.shape, dtype=torch.float16)],
                enabled_precisions={torch.float16},
                workspace_size=1 << 30,
                truncate_long_and_double=True
            )
            print('finish compile stage 2')
            torch_tensorrt.save(sam2_hiera, history_compiled_model, output_format="torchscript", inputs=(dummy_input,))
            sam2_hiera = torch_tensorrt.load(history_compiled_model)
            print('finish compile')
    else:
        assert compile_type is None
        sam2_hiera = sam2_hiera.half()
    # print(sam2_hiera)
    check_precision(sam2_hiera, raw_sam2_hiera, input_shape=inference_shape,
                    precision=torch.float16, test_image_root=real_check_data_dir)
    loop(sam2_hiera, input_shape=inference_shape, precision=torch.float16)


if __name__ == '__main__':
    main()
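
(A note on the settings above: use_fp32_acc=True inserts FP32 casts around matmul accumulation, which usually improves FP16 accuracy at some speed cost, and optimization_level=5 requests TensorRT's most exhaustive tactic search, so engine builds take longer.)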
  • (Not important) I modified build_sam.py to load only the Hiera weights:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os

import torch
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf

import sam2

# Check if the user is running Python from the parent directory of the sam2 repo
# (i.e. the directory where this repo is cloned into) -- this is not supported since
# it could shadow the sam2 package and cause issues.
if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
    # If the user has "sam2/sam2" in their path, they are likely importing the repo itself
    # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
    # This typically happens because the user is running Python from the parent directory
    # that contains the sam2 repo they cloned.
    raise RuntimeError(
        "You're likely running Python from the parent directory of the sam2 repository "
        "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
        "This is not supported since the `sam2` Python package could be shadowed by the "
        "repository name (the repository is also named `sam2` and contains the Python package "
        "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
        "rather than its parent dir, or from your home directory) after installing SAM 2."
    )


HF_MODEL_ID_TO_FILENAMES = {
    "facebook/sam2-hiera-tiny": (
        "configs/sam2/sam2_hiera_t.yaml",
        "sam2_hiera_tiny.pt",
    ),
    "facebook/sam2-hiera-small": (
        "configs/sam2/sam2_hiera_s.yaml",
        "sam2_hiera_small.pt",
    ),
    "facebook/sam2-hiera-base-plus": (
        "configs/sam2/sam2_hiera_b+.yaml",
        "sam2_hiera_base_plus.pt",
    ),
    "facebook/sam2-hiera-large": (
        "configs/sam2/sam2_hiera_l.yaml",
        "sam2_hiera_large.pt",
    ),
    "facebook/sam2.1-hiera-tiny": (
        "configs/sam2.1/sam2.1_hiera_t.yaml",
        "sam2.1_hiera_tiny.pt",
    ),
    "facebook/sam2.1-hiera-small": (
        "configs/sam2.1/sam2.1_hiera_s.yaml",
        "sam2.1_hiera_small.pt",
    ),
    "facebook/sam2.1-hiera-base-plus": (
        "configs/sam2.1/sam2.1_hiera_b+.yaml",
        "sam2.1_hiera_base_plus.pt",
    ),
    "facebook/sam2.1-hiera-large": (
        "configs/sam2.1/sam2.1_hiera_l.yaml",
        "sam2.1_hiera_large.pt",
    ),
}


def build_sam2(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=[],
    apply_postprocessing=True,
    removed_prefix=None,
    **kwargs,
):

    if apply_postprocessing:
        hydra_overrides_extra = hydra_overrides_extra.copy()
        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
        ]
    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path, removed_prefix)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model


def build_sam2_video_predictor(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=[],
    apply_postprocessing=True,
    **kwargs,
):
    hydra_overrides = [
        "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
    ]
    if apply_postprocessing:
        hydra_overrides_extra = hydra_overrides_extra.copy()
        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
            # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
            "++model.binarize_mask_from_pts_for_mem_enc=true",
            # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
            "++model.fill_hole_area=8",
        ]
    hydra_overrides.extend(hydra_overrides_extra)

    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model


def _hf_download(model_id):
    from huggingface_hub import hf_hub_download

    config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
    ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
    return config_name, ckpt_path


def build_sam2_hf(model_id, **kwargs):
    config_name, ckpt_path = _hf_download(model_id)
    return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)


def build_sam2_video_predictor_hf(model_id, **kwargs):
    config_name, ckpt_path = _hf_download(model_id)
    return build_sam2_video_predictor(
        config_file=config_name, ckpt_path=ckpt_path, **kwargs
    )


def _load_checkpoint(model, ckpt_path, removed_prefix=None):
    if ckpt_path is not None:
        sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
        if removed_prefix is not None:
            pos = len(removed_prefix)
            sd = {
                k[pos:]: v for k, v in sd.items() if k.startswith(removed_prefix)
            }
        missing_keys, unexpected_keys = model.load_state_dict(sd)
        if missing_keys:
            logging.error(missing_keys)
            raise RuntimeError()
        if unexpected_keys:
            logging.error(unexpected_keys)
            raise RuntimeError()
        logging.info("Loaded checkpoint sucessfully")
  • (Not important) I modified sam2.1_hiera_t.yaml to build only the Hiera trunk:
# @package _global_

# Model
model:
  _target_: sam2.modeling.backbones.hieradet.Hiera
  embed_dim: 96
  num_heads: 1
  stages: [1, 2, 7, 2]
  global_att_blocks: [5, 7, 9]
  window_pos_embed_bkg_spatial_size: [7, 7]
  input_size: [672, 896]
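
(The input_size entry here is what triggers the resized_pos_embed buffer registered in the modified Hiera above, so the positional embedding is precomputed for the fixed 672×896 inference shape.)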
