
Question about Visualizing Self-Attention in RT-DETR Encoder #478

Open
Anchor1566 opened this issue Oct 20, 2024 · 3 comments

Comments

@Anchor1566 commented Oct 20, 2024

Hello lyuwenyu,

First of all, thank you for your amazing work on RT-DETR! I’ve just started learning about object detection models, and I truly appreciate the innovations that make RT-DETR both faster and more efficient. I've starred the repository and look forward to diving deeper into the project!

I have a question regarding the self-attention visualization in the encoder of RT-DETR. In the original DETR paper, the multi-layer encoder, with multiple attention heads, progressively focuses on different regions of the image, allowing the model to capture fine-grained features such as object edges, shapes, and contours. These attention maps clearly highlight object structures.
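
For reference, those DETR visualizations rely on a simple reshaping trick: the encoder's self-attention forms an (H·W) × (H·W) matrix over the feature map, and reshaping it to (h, w, h, w) lets you pick a query pixel and view its attention over the whole image. A toy sketch of that indexing (random weights standing in for real attention outputs):

```python
import torch
import matplotlib.pyplot as plt

h, w = 25, 25                                             # feature-map size of the encoder input
attn = torch.softmax(torch.randn(h * w, h * w), dim=-1)   # placeholder self-attention weights
sattn = attn.reshape(h, w, h, w)                          # [query_y, query_x, key_y, key_x]

qy, qx = 12, 12                                           # a query location on the feature map
plt.imshow(sattn[qy, qx].numpy(), cmap='cividis')         # that pixel's attention over the image
plt.title(f'self-attention for query ({qy}, {qx})')
plt.show()
```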

In contrast, RT-DETR simplifies the encoder to a single layer to minimize computational overhead and improve inference speed. When I visualize self-attention in RT-DETR, the attention maps do not reveal object shapes or outlines as explicitly as DETR's do.

Is this encoder simplification, made to reduce computational complexity, responsible for the weaker ability to capture complex object features, and therefore for the lack of clear object correlations in the attention maps?

As a beginner, I’d love any guidance or insights you could provide on this topic!

Thank you again for your hard work, and I’m excited to continue learning from this project.

[two image attachments]

@hongu0603

@Anchor1566 I want to ask how you visualize the self-attention in the encoder; I'm also new to learning transformers.

@Anchor1566 (Author)

> @Anchor1566 I want to ask how you visualize the self-attention in the encoder; I'm also new to learning transformers.

You can refer to the code for a good starting point. As for RT-DETR, I can share my implementation with you: the two files below are the inference script and the helper module it imports as utils.wwpig. I've managed to get the attention weights for the encoder, but I haven't correctly indexed the weights for the decoder yet. If you manage to figure that part out, I'd be really interested in seeing your solution as well.

Wishing you all the best and hope you enjoy the learning process!

```python
import argparse
from pathlib import Path
import sys
import time
import os
current_dir = os.path.dirname(os.path.realpath(__file__))
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
sys.path.append(parent_dir)

from src.core import YAMLConfig 

import torch
from torch import nn
from PIL import Image, ImageDraw
from torchvision import transforms

from ultralytics.utils.plotting import Annotator, colors
import matplotlib.pyplot as plt
import utils.wwpig as wwpig


class ImageReader:
    def __init__(self, resize=640, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.transform = transforms.Compose([
            # transforms.Resize((resize, resize)) if isinstance(resize, int) else transforms.Resize(
            #     (resize[0], resize[1])),
            transforms.ToTensor(),
            # transforms.Normalize(mean=mean, std=std),
        ])
        self.resize = resize
        self.pil_img = None   

    def __call__(self, image_path, *args, **kwargs):
        self.pil_img = Image.open(image_path).convert('RGB').resize((self.resize, self.resize))
        return self.transform(self.pil_img).unsqueeze(0)


class Model(nn.Module):
    def __init__(self, config=None, ckpt="") -> None:
        super().__init__()
        self.cfg = YAMLConfig(config, resume=ckpt)
        if ckpt:
            checkpoint = torch.load(ckpt, map_location='cpu') 
            if 'ema' in checkpoint:
                state = checkpoint['ema']['module']
            else:
                state = checkpoint['model']
        else:
            raise AttributeError('only support resume to load model.state_dict by now.')

        # NOTE load train mode state -> convert to deploy mode
        self.cfg.model.load_state_dict(state)

        self.model = self.cfg.model.deploy()
        self.postprocessor = self.cfg.postprocessor.deploy()

        self.backbone = self.model.backbone
        self.encoder = self.model.encoder
        self.decoder = self.model.decoder
        # print(self.postprocessor.deploy_mode)
        
    def forward(self, images, orig_target_sizes):
        outputs = self.model(images)
        return self.postprocessor(outputs, orig_target_sizes)



def get_argparser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", '-c', type=str, default='./configs/rtdetrv2/rtdetrv2_r50vd_m_7x_coco.yml')
    parser.add_argument("--ckpt", '-w', type=str, default=r"C:/Users/zbx/Desktop/code/SDU/visual/pt/rtdetrv2_r50vd_m_7x_coco_ema.pth") # pth
    parser.add_argument("--image", '-i', type=str, default=r"C:/Users/zbx/Desktop/code/SDU/dataset/visualize/cat.jpg") 
    parser.add_argument("--visualize", help="visualize the result of encoder and decoder", default=True)
    # parser.add_argument("--output_dir", '-o', type=str, default='./output')

    args = parser.parse_args()
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    return args


@torch.no_grad()
def main(args):
    img_path = Path(args.image)
    device = torch.device(args.device)
    reader = ImageReader(resize=640)
    model = Model(config=args.config, ckpt=args.ckpt)
    model.to(device=device)

    img = reader(img_path).to(device)
    size = torch.tensor([img.shape[2], img.shape[3]]).to(device)
    
    start_time = time.time()
    output = model(img, size)
    inf_time = time.time() - start_time
    fps = float(1/inf_time)
    print("Inferece time = {} s".format(inf_time, '.4f'))
    print("FPS = {} ".format(fps, '.1f') )
    
    labels, boxes, scores = output
    
    im = reader.pil_img
    annotator = Annotator(im, line_width=3, font_size=16, )
    thrh = 0.6

    for i in range(img.shape[0]):

        scr = scores[i]
        lab = labels[i][scr > thrh]
        box = boxes[i][scr > thrh]

        for j, b in enumerate(box):
            annotator.box_label(b, label=f"{lab[j]}  {scr[j]:.2f}", color=colors(lab[j], True))

    if args.visualize:
        visualizer = wwpig.AttentionVisualizer(model)
        refer_points = [(323, 301), (185, 345), (291, 371), (633, 633)]
        visualizer.plot_attention(img, size, refer_points)


    # save_path = Path(args.output_dir) / img_path.name
    # file_dir = os.path.dirname(args.image)
    # new_file_name = os.path.basename(args.image).split('.')[0] + '_torch'+ os.path.splitext(args.image)[1]
    # new_file_path = file_dir + '/' + new_file_name
    # print('new_file_path: ', new_file_path)
    # im.save(new_file_path)
 

if __name__ == "__main__":
    args = get_argparser()
    main(args)
```

```python
# Saved as utils/wwpig.py so the script above can import it as utils.wwpig
import torch
import torch.nn as nn

import matplotlib.pyplot as plt
from torchvision import transforms
from typing import List, Tuple, Any


class AttentionVisualizer:
    def __init__(self, model: nn.Module):
        self.model = model
        self.conv_features = []
        self.enc_attn_weights = []
        self.dec_attn_weights = []

    def register_hooks(self) -> List[Any]:
        """Register forward hooks on the model."""
        def get_hook(target_list):
            def hook(module, input, output):
                # Backbone features arrive as plain tensors and are kept as-is; attention
                # modules return (attn_output, attn_weights), so keep the weights in that case.
                target_list.append(output if isinstance(output, torch.Tensor) else output[1])
            return hook

        hooks = [
            self.model.backbone.res_layers[-1].register_forward_hook(
                get_hook(self.conv_features)
            ),
            self.model.encoder.encoder[-1].layers[-1].self_attn.register_forward_hook(
                get_hook(self.enc_attn_weights)
            ),
            self.model.decoder.decoder.layers[-1].cross_attn.register_forward_hook(
                get_hook(self.dec_attn_weights)
            )
        ]
        return hooks

    def remove_hooks(self, hooks: List[Any]):
        """Remove registered hooks."""
        for hook in hooks:
            hook.remove()

    def get_attention_maps(self, image: torch.Tensor, size: Tuple[int, int]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get attention maps from the model."""
        hooks = self.register_hooks()
        self.model.model(image, size)
        self.remove_hooks(hooks)

        conv_features = self.conv_features[0]
        enc_attn_weights = self.enc_attn_weights[0]
        
        h, w = conv_features.shape[-2:]
        sattn = enc_attn_weights[0].reshape(h, w, h, w).cpu()
        
        return conv_features, sattn

    def plot_attention(self, image: torch.Tensor, size: Tuple[int, int], reference_points: List[Tuple[int, int]]):
        """Plot attention maps for given reference points."""
        conv_features, sattn = self.get_attention_maps(image, size)
        h, w = conv_features.shape[-2:]
        fact = int(size[0]) / h

        fig = plt.figure(constrained_layout=True, figsize=(25 * 0.7, 8.5 * 0.7))
        gs = fig.add_gridspec(2, 4)
        
        axs = [fig.add_subplot(gs[i, j]) for i, j in [(0, 0), (1, 0), (0, -1), (1, -1)]]

        for idx_o, ax in zip(reference_points, axs):
            idx = (int(idx_o[0] // fact), int(idx_o[1] // fact))
            print(idx)
            ax.imshow(sattn[idx[0], idx[1]], cmap='cividis', interpolation='nearest')
            ax.axis('off')
            ax.set_title(f'self-attention{idx_o}')

        # Hack implementation
        to_pil = transforms.ToPILImage()
        im = to_pil(image.squeeze(0))
        
        fcenter_ax = fig.add_subplot(gs[:, 1:-1])
        fcenter_ax.imshow(im)
        for (y, x) in reference_points:
            scale = im.height / image.shape[-2]
            x = ((x // fact) + 0.5) * fact
            y = ((y // fact) + 0.5) * fact
            fcenter_ax.add_patch(plt.Circle((x * scale, y * scale), fact // 2, color='r'))
        fcenter_ax.axis('off')

        plt.show()
```
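
A note on the file layout implied by the code: the second block is the module the script imports as `utils.wwpig`, so it needs to live at `utils/wwpig.py` relative to the script. The `sys.path` lines at the top of the script add its parent directory, so the script is expected to sit one level below the RT-DETR repo root for `from src.core import YAMLConfig` to resolve.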
    

@hongu0603

> You can refer to the code for a good starting point. [...]

Thanks for the code, but I have a question: what is the purpose of `wwpig` in `import utils.wwpig as wwpig`? How can I install it? After running `pip install utils`, I still get the error `No module named 'utils.wwpig'`.
