Visualizing MAE predictions? #17

Open · Wiqzard opened this issue Jul 13, 2023 · 4 comments

@Wiqzard commented Jul 13, 2023

Great work! Is there an easy way to retrieve the predicted images from MAE inference?

@chayryali (Contributor) commented Jul 13, 2023

Hi there. By default, our pretraining uses normalized pixel targets, so we do not support visualizing predictions. But if you really wanted to, you could adapt the corresponding code from the mae repo. You could potentially un-normalize using the statistics from the visible tokens (or maybe even from the actual targets if it's just a sanity check).
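
For example, a rough sketch of that idea (the function name and tensor shapes here are assumptions, not the repo's API; note that MAE normalizes each target patch by its own mean and variance, so statistics borrowed from the visible patches only give an approximation):

import torch

def approx_unnormalize(pred_patches: torch.Tensor, ref_patches: torch.Tensor) -> torch.Tensor:
    # pred_patches: [num_masked, C*P*P] predictions in normalized-pixel space
    # ref_patches:  [num_visible, C*P*P] raw pixel patches whose statistics we borrow,
    #               e.g. the visible patches (or the true targets for a sanity check)
    mean = ref_patches.mean()
    std = ref_patches.std()
    return pred_patches * std + mean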

@GiilDe commented Aug 8, 2023

@Wiqzard did you end up implementing this?

@DominikFle commented Nov 23, 2023

Something like this should work:

import torch
from matplotlib import pyplot as plt
import numpy as np


def vis_hiera_mae(
    x: torch.Tensor,
    preds: torch.Tensor,
    labels: torch.Tensor,
    mask: torch.Tensor,
    path: str,
    mask_ratio: float = 0.6,
    mask_size: int = 32,
    color_zero_patch: float = 0.5,
    image_index_to_show: int = 0,
):
    """
    Visualizes one image from batch (default the first one)
    Args:
        x: [B,C,H,W] input image
        preds: [B*N*mask_ratio,3*32*32] N=HW/32/32
        labels: [B*N*mask_ratio,3*32*32]
        mask: [B,N]
        path: str path to save the image
        mask_ratio: float between 0 and 1
        mask_size: int 32
        color_zero_patch: number between 0 and 1
        image_index_to_show: int between 0 and (B-1)
    """
    x = x.cpu().detach().clone()[image_index_to_show : image_index_to_show + 1]
    preds = preds.cpu().detach().clone()
    labels = labels.cpu().detach().clone()
    mask = mask.cpu().detach().clone()
    H = x.shape[2]
    W = x.shape[3]
    H_strided = int(H / mask_size)
    W_strided = int(W / mask_size)
    B = mask.shape[0]
    N = mask.shape[1]
    # normalize the input image
    x_input = x.clone().permute(0, 2, 3, 1)
    # see https://github.com/facebookresearch/hiera/blob/1f825a3f1b124fca95f0f652a9f57f6758f55233/hiera/hiera_mae.py#L147
    x_input = x_input.unfold(1, mask_size, mask_size).unfold(2, mask_size, mask_size)
    x_patched = x_input.flatten(1, 2).flatten(2)
    x_mean = x_patched.mean(dim=-1, keepdim=True)
    x_var = x_patched.var(dim=-1, keepdim=True)
    x_patched_normalized = (x_patched - x_mean) / (
        x_var + 1.0e-6
    ) ** 0.5  # [1, N, 3*mask_size*mask_size]

    # n_patches = preds.shape[0]
    n_patches_calc = int(np.ceil(N * mask_ratio))
    # assert n_patches == n_patches_calc, f"{n_patches} != {n_patches_calc}"
    # n_patches_mask = (1 - mask[0].float()).sum()
    # assert n_patches_mask == n_patches_calc, f"{n_patches_mask} != {n_patches_calc}"
    preds_new = preds.reshape(B, n_patches_calc, -1)
    labels_new = labels.reshape(B, n_patches_calc, -1)
    preds_with_gt = torch.zeros((N, 3 * mask_size * mask_size))
    labels_with_gt = torch.zeros((N, 3 * mask_size * mask_size))
    gt_with_mask = torch.zeros((N, 3 * mask_size * mask_size))
    zero_patch = torch.ones((3, mask_size, mask_size)).flatten() * color_zero_patch

    b = image_index_to_show
    i = 0
    for n in range(N):
        if mask[b, n] == 0:
            # masked position: fill in the prediction / label for this patch
            preds_with_gt[n, :] = preds_new[b, i, :]
            labels_with_gt[n, :] = labels_new[b, i, :]
            gt_with_mask[n, :] = zero_patch
            i += 1
        else:
            # visible position: insert the (normalized) ground-truth patch;
            # x was sliced to a single image above, so its batch index is 0
            preds_with_gt[n, :] = x_patched_normalized[0, n, :]
            labels_with_gt[n, :] = x_patched_normalized[0, n, :]
            gt_with_mask[n, :] = x_patched[0, n, :]
    # un-normalize with the per-patch statistics computed above
    # (indexing [0] drops the singleton batch dim so shapes stay [N, C])
    labels_with_gt = labels_with_gt * (x_var[0] + 1.0e-6) ** 0.5 + x_mean[0]
    # rescale to [0, 1] for display
    labels_with_gt = labels_with_gt - labels_with_gt.min()
    labels_with_gt = labels_with_gt / labels_with_gt.max()

    preds_with_gt = preds_with_gt * (x_var[0] + 1.0e-6) ** 0.5 + x_mean[0]
    preds_with_gt = preds_with_gt - preds_with_gt.min()
    preds_with_gt = preds_with_gt / preds_with_gt.max()
    preds_with_gt = preds_with_gt.reshape(H_strided, W_strided, 3, mask_size, mask_size)
    labels_with_gt = labels_with_gt.reshape(
        H_strided, W_strided, 3, mask_size, mask_size
    )
    gt_with_mask = gt_with_mask.reshape(H_strided, W_strided, 3, mask_size, mask_size)
    preds_with_gt = preds_with_gt.permute(0, 1, 3, 4, 2)
    labels_with_gt = labels_with_gt.permute(0, 1, 3, 4, 2)
    gt_with_mask = gt_with_mask.permute(
        0, 1, 3, 4, 2
    )  # [ H_strided, W_strided, mask_size, mask_size, 3]
    # reassemble the full images from patches
    pred_img = torch.zeros((H, W, 3))
    label_img = torch.zeros((H, W, 3))  # filled below but not shown; kept for debugging
    gt_mask_img = torch.zeros((H, W, 3))
    for i in range(H_strided):
        for j in range(W_strided):
            pred_img[
                i * mask_size : (i + 1) * mask_size,
                j * mask_size : (j + 1) * mask_size,
                :,
            ] = preds_with_gt[i, j, :, :, :]
            label_img[
                i * mask_size : (i + 1) * mask_size,
                j * mask_size : (j + 1) * mask_size,
                :,
            ] = labels_with_gt[i, j, :, :, :]
            gt_mask_img[
                i * mask_size : (i + 1) * mask_size,
                j * mask_size : (j + 1) * mask_size,
                :,
            ] = gt_with_mask[i, j, :, :, :]
    img_total = torch.cat((x[0].permute(1, 2, 0), gt_mask_img, pred_img), dim=1)
    # clamp to [0, 1] so plt.imsave accepts the float image
    plt.imsave(path, img_total.clamp(0, 1).numpy())

On the left is the ground truth, in the middle the ground truth with the masked tokens greyed out, and on the right the ground truth with the predicted tokens filled in where the mask was. The predictions shown here were produced with random weights.
[Image: hiera_mae_output — ground truth, masked input, and reconstruction side by side]
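
For completeness, a hedged usage sketch: the torch.hub call follows the hiera README, but the return order of the MAE forward is an assumption (based on hiera_mae.py), as is x being a [B, 3, 224, 224] batch scaled to [0, 1]. With pretrained=True this also shows reconstructions from trained weights.

import torch

# load an MAE-pretrained Hiera model (per the hiera README)
model = torch.hub.load(
    "facebookresearch/hiera",
    model="mae_hiera_base_224",
    pretrained=True,
    checkpoint="mae_in1k",
)
model.eval()

x = torch.rand(1, 3, 224, 224)  # stand-in for a real, [0, 1]-scaled image batch
with torch.no_grad():
    loss, preds, labels, mask = model(x, mask_ratio=0.6)  # assumed return order
vis_hiera_mae(x, preds, labels, mask, "hiera_mae_vis.png", mask_ratio=0.6)

The default mask_size=32 matches Hiera's 32x32-pixel mask units at 224x224 input.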

@guokeqianhg commented

(quoting @DominikFle's snippet and explanation above)

Thank you very much for your help. What should I do if I want to visualize results using trained weights?
