
the possibility of supporting batch input #80

Closed

noahzn opened this issue Jul 3, 2024 · 12 comments

noahzn commented Jul 3, 2024

Hi @fabio-sim. The repo currently only supports batch size = 1. Do you think it's possible, when not enough keypoints are extracted, to pad with a random array so that all images have the same number of keypoints? For example, if the input is 2×N×2 with N1 = 128 for image 1 and N2 = 125 for image 2, could we stack three random rows as fake points so that we can run in batch mode?
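
A minimal sketch of this padding idea, just to illustrate (the pad_keypoints helper and the shapes are made up for the example, not part of the repo):

import torch

def pad_keypoints(kpts: torch.Tensor, target_n: int) -> torch.Tensor:
    """Pad an (N, 2) keypoint tensor with random rows so it has exactly target_n rows."""
    n = kpts.shape[0]
    if n >= target_n:
        return kpts[:target_n]
    filler = torch.rand((target_n - n, 2), dtype=kpts.dtype, device=kpts.device)
    return torch.cat([kpts, filler], dim=0)

# image 1 yields 128 keypoints, image 2 only 125; pad both to 128 and stack into a (2, 128, 2) batch
kpts_img1 = torch.rand(128, 2)
kpts_img2 = torch.rand(125, 2)
batch = torch.stack([pad_keypoints(kpts_img1, 128), pad_keypoints(kpts_img2, 128)])  # (2, 128, 2)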

noahzn changed the title from "the possibility of supporting in batch input" to "the possibility of supporting batch input" on Jul 3, 2024
fabio-sim (Owner) commented

Hello @noahzn, thanks for your interest again. I'll see what I can do.

noahzn (Author) commented Jul 11, 2024

Thank you! I will be waiting for your thoughts.

fabio-sim (Owner) commented

I've added batch input support in 9ebf215. Rather than padding with a random array, I've decided to go with another design choice instead; details here: https://fabio-sim.github.io/blog/accelerating-lightglue-inference-onnx-runtime-tensorrt/

noahzn (Author) commented Jul 17, 2024

That's really amazing! I will take a careful look and give you feedback. Thanks a lot!

noahzn (Author) commented Jul 18, 2024

Hi @fabio-sim, I noticed that you also modified this file, but you didn't use it when exporting models. Can I use it if I want to export non-end2end models with batch input? My two image batches have different numbers of keypoints: for example, the keypoints of image 1 are always (B × 100 × 2) and those of image 2 are always (B × 200 × 2).

fabio-sim (Owner) commented

Hi, that file is from the original implementation, so it's unrelated to the export.

For your use case, I recommend passing the left and right batches separately, like this (note: untested):

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from ..config import Extractor
from ..ops import multi_head_attention_dispatch

torch.backends.cudnn.deterministic = True


class LearnableFourierPositionalEncoding(nn.Module):
    def __init__(self, M: int, descriptor_dim: int, num_heads: int, gamma: float = 1.0) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = descriptor_dim // num_heads
        self.Wr = nn.Linear(M, head_dim // 2, bias=False)
        self.gamma = gamma
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """encode position vector"""
        projected = self.Wr(x)
        cosines, sines = torch.cos(projected), torch.sin(projected)
        emb = torch.stack([cosines, sines])
        return emb.repeat_interleave(2, dim=3).repeat(1, 1, 1, self.num_heads).unsqueeze(4)


class TokenConfidence(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())

    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
        """get confidence tokens"""
        return (
            self.token(desc0.detach()).squeeze(-1),
            self.token(desc1.detach()).squeeze(-1),
        )


class SelfBlock(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, bias: bool = True) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.ffn = nn.Sequential(
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )

    def forward(self, x: torch.Tensor, encoding: torch.Tensor) -> torch.Tensor:
        b, n, _ = x.shape
        qkv: torch.Tensor = self.Wqkv(x)
        qkv = qkv.reshape((b, n, self.embed_dim, 3))
        qk, v = qkv[..., :2], qkv[..., 2]
        qk = self.apply_cached_rotary_emb(encoding, qk)
        q, k = qk[..., 0], qk[..., 1]
        context = multi_head_attention_dispatch(q, k, v, self.num_heads)
        message = self.out_proj(context)
        return x + self.ffn(torch.concat([x, message], 2))

    def rotate_half(self, qk: torch.Tensor) -> torch.Tensor:
        b, n, _, _ = qk.shape
        qk = qk.reshape((b, n, self.num_heads, self.head_dim // 2, 2, 2))
        qk = torch.stack((-qk[..., 1, :], qk[..., 0, :]), dim=4)
        qk = qk.reshape((b, n, self.embed_dim, 2))
        return qk

    def apply_cached_rotary_emb(self, encoding: torch.Tensor, qk: torch.Tensor) -> torch.Tensor:
        return qk * encoding[0] + self.rotate_half(qk) * encoding[1]


class CrossBlock(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, bias: bool = True) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.to_qk = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.to_v = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.to_out = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.ffn = nn.Sequential(
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )

    def forward(self, descriptors0: torch.Tensor, descriptors1: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        b, _, _ = descriptors0.shape
        qk0, v0 = self.to_qk(descriptors0), self.to_v(descriptors0)
        qk1, v1 = self.to_qk(descriptors1), self.to_v(descriptors1)

        m0 = multi_head_attention_dispatch(qk0, qk1, v1, self.num_heads)
        m0 = self.to_out(m0)
        descriptors0 = descriptors0 + self.ffn(torch.concat([descriptors0, m0], 2))

        m1 = multi_head_attention_dispatch(qk1, qk0, v0, self.num_heads)
        m1 = self.to_out(m1)
        descriptors1 = descriptors1 + self.ffn(torch.concat([descriptors1, m1], 2))
        return descriptors0, descriptors1


class TransformerLayer(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.self_attn = SelfBlock(embed_dim, num_heads)
        self.cross_attn = CrossBlock(embed_dim, num_heads)

    def forward(
        self, descriptors0: torch.Tensor, descriptors1: torch.Tensor, encodings0: torch.Tensor, encodings1: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        descriptors0 = self.self_attn(descriptors0, encodings0)
        descriptors1 = self.self_attn(descriptors1, encodings1)
        return self.cross_attn(descriptors0, descriptors1)


def sigmoid_log_double_softmax(similarities: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor) -> torch.Tensor:
    """create the log assignment matrix from logits and similarity"""
    certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
    scores0 = F.log_softmax(similarities, 2)
    scores1 = F.log_softmax(similarities, 1)
    scores = scores0 + scores1 + certainties
    return scores


class MatchAssignment(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.scale = dim**0.25
        self.final_proj = nn.Linear(dim, dim, bias=True)
        self.matchability = nn.Linear(dim, 1, bias=True)

    def forward(self, descriptors0: torch.Tensor, descriptors1: torch.Tensor) -> torch.Tensor:
        """build assignment matrix from descriptors"""
        mdescriptors0 = self.final_proj(descriptors0) / self.scale
        mdescriptors1 = self.final_proj(descriptors1) / self.scale
        similarities = mdescriptors0 @ mdescriptors1.transpose(1, 2)
        z0 = self.matchability(descriptors0)
        z1 = self.matchability(descriptors1)
        scores = sigmoid_log_double_softmax(similarities, z0, z1)
        return scores

    def get_matchability(self, desc: torch.Tensor):
        return torch.sigmoid(self.matchability(desc)).squeeze(-1)


def filter_matches(scores: torch.Tensor, threshold: float):
    """obtain matches from a log assignment matrix [BxNxN]"""
    max0 = torch.topk(scores, k=1, dim=2, sorted=False)  # scores.max(2)
    max1 = torch.topk(scores, k=1, dim=1, sorted=False)  # scores.max(1)
    m0, m1 = max0.indices[:, :, 0], max1.indices[:, 0, :]

    indices = torch.arange(m0.shape[1], device=m0.device).expand_as(m0)
    mutual = indices == m1.gather(1, m0)
    mscores = max0.values[:, :, 0].exp()
    valid = mscores > threshold

    b_idx, m0_idx = torch.where(valid & mutual)
    m1_idx = m0[b_idx, m0_idx]
    matches = torch.concat([b_idx[:, None], m0_idx[:, None], m1_idx[:, None]], 1)
    mscores = mscores[b_idx, m0_idx]
    return matches, mscores


class LightGlue(nn.Module):
    version = "v0.1_arxiv"
    url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"

    def __init__(
        self,
        extractor: Extractor,
        descriptor_dim: int = 256,
        num_heads: int = 4,
        n_layers: int = 9,
        filter_threshold: float = 0.1,  # match threshold
        depth_confidence: float = -1,  # -1 is no early stopping, recommend: 0.95
        width_confidence: float = -1,  # -1 is no point pruning, recommend: 0.99
    ) -> None:
        super().__init__()

        self.descriptor_dim = descriptor_dim
        self.num_heads = num_heads
        self.n_layers = n_layers
        self.filter_threshold = filter_threshold
        self.depth_confidence = depth_confidence
        self.width_confidence = width_confidence

        if extractor.dim != self.descriptor_dim:
            self.input_proj = nn.Linear(extractor.dim, self.descriptor_dim, bias=True)
        else:
            self.input_proj = nn.Identity()

        self.posenc = LearnableFourierPositionalEncoding(2, self.descriptor_dim, self.num_heads)

        d, h, n = self.descriptor_dim, self.num_heads, self.n_layers

        self.transformers = nn.ModuleList([TransformerLayer(d, h) for _ in range(n)])

        self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])

        self.token_confidence = nn.ModuleList([TokenConfidence(d) for _ in range(n - 1)])
        self.register_buffer(
            "confidence_thresholds",
            torch.Tensor([self.confidence_threshold(i) for i in range(n)]),
        )

        state_dict = torch.hub.load_state_dict_from_url(self.url.format(self.version, extractor.value))

        # rename old state dict entries
        for i in range(n):
            pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
            state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
            pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
            state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
        self.load_state_dict(state_dict, strict=False)

    def forward(
        self,
        keypoints0: torch.Tensor,  # (2B, N, 2), normalized
        keypoints1: torch.Tensor,
        descriptors0: torch.Tensor,  # (2B, N, D)
        descriptors1: torch.Tensor,
    ):
        descriptors0 = self.input_proj(descriptors0)
        descriptors1 = self.input_proj(descriptors1)

        # positional embeddings
        encodings0 = self.posenc(keypoints0)  # (2, 2B, *, 64, 1)
        encodings1 = self.posenc(keypoints1)

        # GNN + final_proj + assignment
        for i in range(self.n_layers):
            # self+cross attention
            descriptors0, descriptors1 = self.transformers[i](descriptors0, descriptors1, encodings0, encodings1)

        # assignment from the last layer's head
        scores = self.log_assignment[self.n_layers - 1](descriptors0, descriptors1)  # (B, N, N)
        matches, mscores = filter_matches(scores, self.filter_threshold)
        return matches, mscores  # (M, 3), (M,)

    def confidence_threshold(self, layer_index: int) -> float:
        """scaled confidence threshold"""
        threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.n_layers)
        return np.clip(threshold, 0, 1)

    def get_pruning_mask(
        self,
        confidences: torch.Tensor | None,
        scores: torch.Tensor,
        layer_index: int,
    ) -> torch.Tensor:
        """mask points which should be removed"""
        keep = scores > (1 - self.width_confidence)
        if confidences is not None:  # Low-confidence points are never pruned.
            keep |= confidences <= self.confidence_thresholds[layer_index]
        return keep

    def check_if_stop(
        self,
        confidences0: torch.Tensor,
        confidences1: torch.Tensor,
        layer_index: int,
        num_points: int,
    ) -> torch.Tensor:
        """evaluate stopping condition"""
        confidences = torch.cat([confidences0, confidences1], -1)
        threshold = self.confidence_thresholds[layer_index]
        ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
        return ratio_confident > self.depth_confidence

and then adjusting the Pipeline class to orchestrate SuperPoint(100) and SuperPoint(200) accordingly.
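
A rough sketch of what that orchestration could look like (the SuperPoint constructor, its max_num_keypoints argument, the output layout, and Extractor.superpoint are assumptions for illustration, not the repo's exact API):

import torch

# hypothetical: one extractor instance per side, each capped at a fixed keypoint count
extractor_left = SuperPoint(max_num_keypoints=100).eval()   # assumed constructor argument
extractor_right = SuperPoint(max_num_keypoints=200).eval()
matcher = LightGlue(Extractor.superpoint).eval()            # LightGlue class from the snippet above

with torch.no_grad():
    kpts0, desc0 = extractor_left(images0)    # assumed outputs: (B, 100, 2), (B, 100, D)
    kpts1, desc1 = extractor_right(images1)   # assumed outputs: (B, 200, 2), (B, 200, D)
    # keypoints are expected to be normalized before being passed to the matcher
    matches, mscores = matcher(kpts0, kpts1, desc0, desc1)   # (M, 3), (M,)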

noahzn (Author) commented Jul 20, 2024

Hi @fabio-sim, thank you very much! I'm working on the code now, but I ran into an error:

fused_multi_head_attention = torch.library.custom_op(CUSTOM_OP_NAME, mutates_args=())(multi_head_attention)
AttributeError: module 'torch.library' has no attribute 'custom_op'

My torch version is >= 2.1.

fabio-sim (Owner) commented

Oh, apologies, my mistake. torch.library.custom_op needs torch >= 2.4; I should have added a check.
I think it's fine if you comment it out.
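
For anyone else hitting this on an older torch, a minimal sketch of such a check, reusing CUSTOM_OP_NAME and multi_head_attention from the repo's ops module quoted above:

import torch

if hasattr(torch.library, "custom_op"):  # torch.library.custom_op exists only in torch >= 2.4
    fused_multi_head_attention = torch.library.custom_op(CUSTOM_OP_NAME, mutates_args=())(multi_head_attention)
else:
    # older torch versions: fall back to the plain eager function
    fused_multi_head_attention = multi_head_attention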

noahzn (Author) commented Jul 22, 2024

@fabio-sim Thank you for your comments.

import cv2
import numpy as np
from lightglue import viz2d  # assuming the visualization helpers from the original LightGlue repo

orig_image0 = cv2.imread(img0_path, cv2.IMREAD_COLOR)
orig_image1 = cv2.imread(img1_path, cv2.IMREAD_COLOR)
viz2d.plot_images([orig_image0, orig_image1])

assert np.all(kpts0[2][matches[..., 1]] == kpts0[0][matches[..., 1]])
assert np.all(kpts1[2][matches[..., 2]] == kpts1[0][matches[..., 2]])
viz2d.plot_matches(kpts0[0][matches[..., 1]], kpts1[0][matches[...,2]], color="lime", lw=0.2)

viz2d.save_plot('aaa1.jpg', dpi=300)
viz2d.plt.show()
viz2d.plot_matches(kpts0[2][matches[..., 1]], kpts1[2][matches[..., 2]], color="lime", lw=0.2)
viz2d.save_plot('aaa2.jpg', dpi=300)
viz2d.plt.show()

I used the code above to visualize. I used batch size = 4; the first and third image pairs are identical, and for the other two pairs I used random arrays. The asserts confirm that the outputs for the first and third pairs are the same. However, when visualizing the results, several matches keep changing and are incorrect. Do you know the reason?

(attached screenshots: myplot1, myplot2)

Update: the problem has been solved. I wasn't parsing the returned matches correctly; now it works. Thanks a million for your help! I'm closing this ticket now.

noahzn closed this as completed on Jul 22, 2024

noahzn (Author) commented Jul 22, 2024

Hi, sorry, I still have a problem.

def multi_head_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int) -> torch.Tensor:
    b, n, d = q.shape
    head_dim = d // num_heads
    q, k, v = (t.reshape((b, n, num_heads, head_dim)).transpose(1, 2) for t in (q, k, v))
    return F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape((b, n, d))

I found that when the two sides of an image pair have different numbers of keypoints, multi_head_attention throws an error. For example, the left images have shape (2, 99, 64) and the right images (2, 256, 64); 256 is the maximum number of keypoints I set, but only 99 keypoints are extracted from the left images. The multi_head_attention function then throws this error:

q, k, v = (t.reshape((b, n, num_heads, head_dim)).transpose(1, 2) for t in (q, k, v))
RuntimeError: shape '[2, 99, 4, 16]' is invalid for input of size 32768

because for the right images the shape is [2, 256, 4, 16]: 2 × 256 × 4 × 16 = 32768. (Please note that my keypoint descriptors are 64-D rather than 256-D; it's a customized network.)

I tried to modify the function as follows; the error was gone, but the matching result is bad.

def multi_head_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int) -> torch.Tensor:
    b, n, d = q.shape
    nk = k.shape[1]
    head_dim = d // num_heads
    q = q.reshape(b, n, num_heads, head_dim).transpose(1, 2)
    k = k.reshape(b, nk, num_heads, head_dim).transpose(1, 2)
    v = v.reshape(b, nk, num_heads, head_dim).transpose(1, 2)
    return F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape((b, n, d))

Could you help me with that? Thank you in advance!

Update: I have used the old implementation of CrossBlock and it works with different numbers of keypoints.

noahzn reopened this on Jul 22, 2024

fabio-sim (Owner) commented

Ah yes, if you have different numbers of keypoints, then the sequence length of Q differs from that of K & V.

Use something like this instead:

def multi_head_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int) -> torch.Tensor:
    b, n, d = q.shape
    _, n1, _ = k.shape
    head_dim = d // num_heads
    q = q.reshape((b, n, num_heads, head_dim)).transpose(1, 2)
    k, v = (t.reshape((b, n1, num_heads, head_dim)).transpose(1, 2) for t in (k, v))
    return F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape((b, n, d))
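
A quick shape check of this version with the dimensions from the error above ((2, 99, 64) queries against (2, 256, 64) keys/values, 4 heads):

import torch
import torch.nn.functional as F

q = torch.rand(2, 99, 64)    # 99 keypoints on the left images
k = torch.rand(2, 256, 64)   # 256 keypoints on the right images
v = torch.rand(2, 256, 64)
out = multi_head_attention(q, k, v, num_heads=4)  # the function defined just above
print(out.shape)  # torch.Size([2, 99, 64])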

noahzn closed this as completed on Jul 24, 2024

noahzn (Author) commented Jul 24, 2024

Thank you again for your help!
