Addition of More Pooling Methods #2048

Open · wants to merge 1 commit into base: main
45 changes: 45 additions & 0 deletions timm/layers/gen_maxpool.py
@@ -0,0 +1,45 @@
import torch
import torch.nn as nn

class GeneralizedMP(nn.Module):
"""
Implements Generalized Max Pooling (GMP), a global pooling operation that
generalizes the concept of max pooling to capture more complex and discriminative
features from the input tensor.
The class operates by computing a linear kernel based on the input tensor,
then solving a linear system to obtain the pooling coefficients. These coefficients
are used to weigh and aggregate the input features, resulting in a pooled feature vector.
Parameters:
lamb (float, optional): A regularization parameter used in the linear system
to ensure numerical stability. Default value is 1e3.
Note:
- The input tensor is expected to be in the format (B, D, H, W), where B is batch size,
D is depth (channels), H is height, and W is width.
- The implementation uses PyTorch's linear algebra functions to solve the linear system.
"""
    def __init__(self, lamb: float = 1e3):
        super().__init__()
        self.lamb = nn.Parameter(lamb * torch.ones(1))

def forward(self, x):
B, D, H, W = x.shape
N = H * W
        identity = torch.eye(N, device=x.device)
# reshape x, s.t. we can use the gmp formulation as a global pooling operation
x = x.view(B, D, N)
x = x.permute(0, 2, 1)
# compute the linear kernel
K = torch.bmm(x, x.permute(0, 2, 1))
# solve the linear system (K + lambda * I) * alpha = ones
A = K + self.lamb * identity
        o = torch.ones(B, N, 1, device=x.device)
        alphas = torch.linalg.solve(A, o)
alphas = alphas.view(B, 1, -1)
xi = torch.bmm(alphas, x)
xi = xi.view(B, -1)
return xi
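A minimal usage sketch of GeneralizedMP as a global pooling head (the shapes and the direct module import below are illustrative assumptions, not part of this diff):

import torch
from timm.layers.gen_maxpool import GeneralizedMP  # module added by this PR

pool = GeneralizedMP(lamb=1e3)
feats = torch.randn(2, 512, 7, 7)   # (B, D, H, W) backbone feature map
desc = pool(feats)                  # per-sample solve of (K + lamb*I) alpha = 1, then weighted aggregation
print(desc.shape)                   # torch.Size([2, 512])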
97 changes: 97 additions & 0 deletions timm/layers/how_pool.py
@@ -0,0 +1,97 @@
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

class HOWPooling(nn.Module):
"""
Implements HOW, as described in the paper
'Learning and Aggregating Deep Local Descriptors for Instance-Level Recognition'.
This pooling method focuses on aggregating deep local descriptors
for enhanced instance-level recognition.
The class includes functions for L2-based attention, smoothing average pooling,
L2 normalization (l2n), and a forward method that integrates these components.
It applies dimensionality reduction to the input features before the pooling operation.
Parameters:
input_dim (int): Dimension of the input features.
dim_reduction (int): Target dimension after reduction.
kernel_size (int): Size of the kernel used in smoothing average pooling.
"""
def __init__(self, input_dim = 512, dim_reduction = 128, kernel_size = 3):
super(HOWPooling, self).__init__()
self.kernel_size = kernel_size
self.dimreduction = ConvDimReduction(input_dim, dim_reduction)

    def L2Attention(self, x):
        # per-location attention weights: channel-wise L2 norm, (B, D, H, W) -> (B, H, W)
        return (x.pow(2.0).sum(1) + 1e-10).sqrt()

def smoothing_avg_pooling(self, feats):
"""Smoothing average pooling
:param torch.Tensor feats: Feature map
:param int kernel_size: kernel size of pooling
:return torch.Tensor: Smoothend feature map
"""
pad = self.kernel_size // 2
return F.avg_pool2d(feats, (self.kernel_size, self.kernel_size), stride=1, padding=pad,
count_include_pad=False)

def l2n(self, x, eps=1e-6):
return x / (torch.norm(x, p=2, dim=1, keepdim=True) + eps).expand_as(x)

def forward(self, x):

weights = self.L2Attention(x)
x = self.smoothing_avg_pooling(x)
x = self.dimreduction(x)
x = (x * weights.unsqueeze(1)).sum((-2, -1))
return self.l2n(x)

class ConvDimReduction(nn.Conv2d):
"""
Implements dimensionality reduction using a convolutional layer. This layer is
designed for reducing the dimensions of input features, particularly for use in
aggregation and pooling operations like in the HOWPooling class.
The class also includes methods for learning and applying PCA whitening with shrinkage,
which is a technique to reduce dimensionality while preserving important feature variations.
Parameters:
input_dim (int): The input dimension (number of channels) of the network.
dim (int): The target output dimension for the whitening process.
"""
def __init__(self, input_dim, dim):
super().__init__(input_dim, dim, (1, 1), padding=0, bias=True)

    @staticmethod
    def pcawhitenlearn_shrinkage(X, s=1.0):
        """Learn PCA whitening with shrinkage from the given descriptors."""
N = X.shape[0]

# Learning PCA w/o annotations
m = X.mean(axis=0, keepdims=True)
Xc = X - m
Xcov = np.dot(Xc.T, Xc)
Xcov = (Xcov + Xcov.T) / (2*N)
eigval, eigvec = np.linalg.eig(Xcov)
order = eigval.argsort()[::-1]
eigval = eigval[order]
eigvec = eigvec[:, order]

eigval = np.clip(eigval, a_min=1e-14, a_max=None)
P = np.dot(np.linalg.inv(np.diag(np.power(eigval, 0.5*s))), eigvec.T)

return m, P.T

def initialize_pca_whitening(self, des):
"""Initialize PCA whitening from given descriptors. Return tuple of shift and projection."""
m, P = self.pcawhitenlearn_shrinkage(des)
m, P = m.T, P.T

projection = torch.Tensor(P[:self.weight.shape[0], :]).unsqueeze(-1).unsqueeze(-1)
self.weight.data = projection.to(self.weight.device)

projected_shift = -torch.mm(torch.FloatTensor(P), torch.FloatTensor(m)).squeeze()
self.bias.data = projected_shift[:self.weight.shape[0]].to(self.bias.device)
return m.T, P.T
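A minimal usage sketch of HOWPooling (shapes below are assumed for illustration; PCA-whitening initialization via initialize_pca_whitening is optional and omitted here):

import torch
from timm.layers.how_pool import HOWPooling  # module added by this PR

pool = HOWPooling(input_dim=512, dim_reduction=128, kernel_size=3)
feats = torch.randn(2, 512, 14, 14)  # (B, input_dim, H, W) local descriptors
desc = pool(feats)                   # attention-weighted, dimension-reduced, L2-normalized: (2, 128)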
36 changes: 36 additions & 0 deletions timm/layers/lse_pool.py
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class LSEPool(nn.Module):
"""
Implements LogSumExp (LSE) pooling, an advanced pooling technique that provides
a smooth approximation to the max pooling operation. This pooling method is useful
for capturing the global distribution of features across spatial dimensions (height and width)
of the input tensor, while still maintaining differentiability.
The class supports learnable pooling behavior with an optional learnable parameter 'r'.
When 'r' is large, LSE pooling closely approximates max pooling, and when 'r' is small,
it behaves more like average pooling. The 'r' parameter can either be a fixed value or
learned during training.
Parameters:
r (float, optional): The initial value of the pooling parameter. Default is 10.
learnable (bool, optional): If True, 'r' is a learnable parameter. Default is True.
"""

def __init__(self, r=10, learnable=True):
super(LSEPool, self).__init__()
if learnable:
self.r = nn.Parameter(torch.ones(1) * r)
else:
self.r = r

    def forward(self, x):
        # numerically stable log-sum-exp over the spatial dimensions:
        # the per-channel max is subtracted before exponentiation and added back at the end
        s = x.size(2) * x.size(3)
        x_max = F.adaptive_max_pool2d(x, 1)
        exp = torch.exp(self.r * (x - x_max))
        sumexp = 1 / s * torch.sum(exp, dim=(2, 3))
        sumexp = sumexp.view(sumexp.size(0), -1, 1, 1)
        logsumexp = x_max + 1 / self.r * torch.log(sumexp)
        return logsumexp
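A minimal usage sketch of LSEPool, including a rough check of the two limits of the pooling parameter (shapes and the chosen r values are assumptions, not part of this diff):

import torch
from timm.layers.lse_pool import LSEPool  # module added by this PR

feats = torch.randn(2, 512, 7, 7)
pooled = LSEPool(r=10, learnable=True)(feats)       # (2, 512, 1, 1)
near_max = LSEPool(r=1e4, learnable=False)(feats)   # approximately the spatial max per channel
near_avg = LSEPool(r=1e-4, learnable=False)(feats)  # approximately the spatial mean per channel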
101 changes: 101 additions & 0 deletions timm/layers/simpool.py
@@ -0,0 +1,101 @@
import torch
import torch.nn as nn

class SimPool(nn.Module):
"""
Implements SimPool as described in the ICCV 2023 paper
"Keep It SimPool: Who Said Supervised Transformers Suffer from Attention Deficit?".
This class is designed to provide an efficient and effective pooling strategy
for both Transformer and CNN architectures.
SimPool applies a global average pooling (GAP) operation as an initial step
and then utilizes a simple but powerful attention mechanism to refine the pooled features.
The attention mechanism uses linear transformations for queries and keys, followed by
softmax normalization to compute attention scores.
Parameters:
dim (int): Dimension of the input features.
num_heads (int, optional): Number of attention heads. Default is 1.
qkv_bias (bool, optional): If True, adds a learnable bias to query, key, value projections. Default is False.
qk_scale (float, optional): Scaling factor for query-key dot product. Default is None, which uses the inverse square root of head dimensions.
gamma (float, optional): Scaling parameter for value vectors, used if not None. Default is None.
use_beta (bool, optional): If True, adds a learnable translation to the value vectors after applying gamma. Default is False.
"""
def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, gamma=None, use_beta=False):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5

self.norm_patches = nn.LayerNorm(dim, eps=1e-6)

self.wq = nn.Linear(dim, dim, bias=qkv_bias)
self.wk = nn.Linear(dim, dim, bias=qkv_bias)

        if gamma is not None:
            if use_beta:
                # learnable translation applied after the gamma-scaled aggregation
                self.beta = nn.Parameter(torch.zeros(1))
            self.eps = 1e-6

        self.gamma = gamma
        self.use_beta = use_beta

def prepare_input(self, x):
if len(x.shape) == 3: # Transformer
# Input tensor dimensions:
# x: (B, N, d), where B is batch size, N are patch tokens, d is depth (channels)
B, N, d = x.shape
gap_cls = x.mean(-2) # (B, N, d) -> (B, d)
gap_cls = gap_cls.unsqueeze(1) # (B, d) -> (B, 1, d)
return gap_cls, x
if len(x.shape) == 4: # CNN
# Input tensor dimensions:
# x: (B, d, H, W), where B is batch size, d is depth (channels), H is height, and W is width
B, d, H, W = x.shape
gap_cls = x.mean([-2, -1]) # (B, d, H, W) -> (B, d)
x = x.reshape(B, d, H*W).permute(0, 2, 1) # (B, d, H, W) -> (B, d, H*W) -> (B, H*W, d)
gap_cls = gap_cls.unsqueeze(1) # (B, d) -> (B, 1, d)
return gap_cls, x
else:
raise ValueError(f"Unsupported number of dimensions in input tensor: {len(x.shape)}")

def forward(self, x):
# Prepare input tensor and perform GAP as initialization
gap_cls, x = self.prepare_input(x)

# Prepare queries (q), keys (k), and values (v)
q, k, v = gap_cls, self.norm_patches(x), self.norm_patches(x)

# Extract dimensions after normalization
Bq, Nq, dq = q.shape
Bk, Nk, dk = k.shape
Bv, Nv, dv = v.shape

# Check dimension consistency across batches and channels
assert Bq == Bk == Bv
assert dq == dk == dv

# Apply linear transformation for queries and keys then reshape
qq = self.wq(q).reshape(Bq, Nq, self.num_heads, dq // self.num_heads).permute(0, 2, 1, 3) # (Bq, Nq, dq) -> (B, num_heads, Nq, dq/num_heads)
kk = self.wk(k).reshape(Bk, Nk, self.num_heads, dk // self.num_heads).permute(0, 2, 1, 3) # (Bk, Nk, dk) -> (B, num_heads, Nk, dk/num_heads)

vv = v.reshape(Bv, Nv, self.num_heads, dv // self.num_heads).permute(0, 2, 1, 3) # (Bv, Nv, dv) -> (B, num_heads, Nv, dv/num_heads)

# Compute attention scores
attn = (qq @ kk.transpose(-2, -1)) * self.scale
# Apply softmax for normalization
attn = attn.softmax(dim=-1)

# If gamma scaling is used
if self.gamma is not None:
# Apply gamma scaling on values and compute the weighted sum using attention scores
x = torch.pow(attn @ torch.pow((vv - vv.min() + self.eps), self.gamma), 1/self.gamma) # (B, num_heads, Nv, dv/num_heads) -> (B, 1, 1, d)
# If use_beta, add a learnable translation
if self.use_beta:
x = x + self.beta
else:
# Compute the weighted sum using attention scores
x = (attn @ vv).transpose(1, 2).reshape(Bq, Nq, dq)

return x.squeeze()
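A minimal usage sketch of SimPool showing that the same module handles both CNN feature maps and transformer token sequences (the dimensions below are assumptions, not part of this diff):

import torch
from timm.layers.simpool import SimPool  # module added by this PR

pool = SimPool(dim=768, num_heads=1)
cnn_feats = torch.randn(2, 768, 7, 7)   # (B, d, H, W) CNN feature map
vit_feats = torch.randn(2, 196, 768)    # (B, N, d) patch tokens
out_cnn = pool(cnn_feats)               # (2, 768)
out_vit = pool(vit_feats)               # (2, 768)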
119 changes: 119 additions & 0 deletions timm/layers/slot_pool.py
@@ -0,0 +1,119 @@
import torch
import torch.nn as nn
from torch.nn import init

def prepare_input(x):
"""
Prepares the input tensor for different neural network architectures (Transformers and CNNs).
This function adjusts the shape of the input tensor based on its dimensionality.
It supports input tensors for Transformers (3D) and CNNs (4D), ensuring they are
correctly formatted for these architectures.
For a Transformer, it expects a tensor of shape (B, N, d), where B is the batch size,
N are patch tokens, and d is the depth (channels). The tensor is returned as is.
For a CNN, it expects a tensor of shape (B, d, H, W), where B is the batch size,
d is the depth (channels), H is the height, and W is the width. The tensor is reshaped
and permuted to the shape (B, H*W, d) to match CNN input requirements.
Parameters:
x (torch.Tensor): The input tensor to be preprocessed.
"""
if len(x.shape) == 3: # Transformer
# Input tensor dimensions:
# x: (B, N, d), where B is batch size, N are patch tokens, d is depth (channels)
B, N, d = x.shape
return x
if len(x.shape) == 4: # CNN
# Input tensor dimensions:
# x: (B, d, H, W), where B is batch size, d is depth (channels), H is height, and W is width
B, d, H, W = x.shape
x = x.reshape(B, d, H*W).permute(0, 2, 1) # (B, d, H, W) -> (B, d, H*W) -> (B, H*W, d)
return x
else:
raise ValueError(f"Unsupported number of dimensions in input tensor: {len(x.shape)}")

class SlotPooling(nn.Module):
"""
This class implements the Slot Attention module as described in the paper
"Object-Centric Learning with Slot Attention".
The module is designed for object-centric learning tasks and utilizes the concept of
'slots' to represent distinct object features within an input.
It iteratively refines these slots through a pooling mechanism to capture
complex object representations.
Parameters:
num_slots (int): Number of slots to be used.
dim (int): Dimensionality of the input features.
iters (int, optional): Number of iterations for slot refinement. Default is 3.
eps (float, optional): A small epsilon value to avoid division by zero. Default is 1e-8.
hidden_dim (int, optional): Dimensionality of the hidden layer within the module. Default is 128.
"""
def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128):
super().__init__()
self.num_slots = num_slots
self.iters = iters
self.eps = eps
self.scale = dim ** -0.5

self.slots_mu = nn.Parameter(torch.randn(1, 1, dim))

self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, dim))
init.xavier_uniform_(self.slots_logsigma)

self.to_q = nn.Linear(dim, dim)
self.to_k = nn.Linear(dim, dim)
self.to_v = nn.Linear(dim, dim)

self.gru = nn.GRUCell(dim, dim)

hidden_dim = max(dim, hidden_dim)

self.mlp = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.ReLU(inplace = True),
nn.Linear(hidden_dim, dim)
)

self.norm_input = nn.LayerNorm(dim)
self.norm_slots = nn.LayerNorm(dim)
self.norm_pre_ff = nn.LayerNorm(dim)

def forward(self, inputs, num_slots = None):
inputs = prepare_input(inputs)
b, n, d, device, dtype = *inputs.shape, inputs.device, inputs.dtype
n_s = num_slots if num_slots is not None else self.num_slots

mu = self.slots_mu.expand(b, n_s, -1)
sigma = self.slots_logsigma.exp().expand(b, n_s, -1)

slots = mu + sigma * torch.randn(mu.shape, device = device, dtype = dtype)

inputs = self.norm_input(inputs)
k, v = self.to_k(inputs), self.to_v(inputs)

for _ in range(self.iters):
slots_prev = slots

slots = self.norm_slots(slots)
q = self.to_q(slots)

dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
attn = dots.softmax(dim=1) + self.eps

attn = attn / attn.sum(dim=-1, keepdim=True)

updates = torch.einsum('bjd,bij->bid', v, attn)

slots = self.gru(
updates.reshape(-1, d),
slots_prev.reshape(-1, d)
)

slots = slots.reshape(b, -1, d)
slots = slots + self.mlp(self.norm_pre_ff(slots))
slots = slots.max(dim=1)[0]

return slots
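A minimal usage sketch of SlotPooling (the slot count and shapes are assumptions; a (B, N, d) token tensor is accepted as well):

import torch
from timm.layers.slot_pool import SlotPooling  # module added by this PR

pool = SlotPooling(num_slots=4, dim=512)
feats = torch.randn(2, 512, 7, 7)   # (B, d, H, W) feature map
pooled = pool(feats)                # iterative slot refinement, then max over slots: (2, 512)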
142 changes: 142 additions & 0 deletions timm/layers/vit_pool.py
@@ -0,0 +1,142 @@
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional, Tuple

from torch import Tensor

def prepare_input(x):
"""
Prepares the input tensor for different neural network architectures (Transformers and CNNs).
This function adjusts the shape of the input tensor based on its dimensionality.
It supports input tensors for Transformers (3D) and CNNs (4D), ensuring they are
correctly formatted for these architectures.
For a Transformer, it expects a tensor of shape (B, N, d), where B is the batch size,
N are patch tokens, and d is the depth (channels). The tensor is returned as is.
For a CNN, it expects a tensor of shape (B, d, H, W), where B is the batch size,
d is the depth (channels), H is the height, and W is the width. The tensor is reshaped
and permuted to the shape (B, H*W, d) to match CNN input requirements.
Parameters:
x (torch.Tensor): The input tensor to be preprocessed.
"""
if len(x.shape) == 3: # Transformer
# Input tensor dimensions:
# x: (B, N, d), where B is batch size, N are patch tokens, d is depth (channels)
B, N, d = x.shape
return x
if len(x.shape) == 4: # CNN
# Input tensor dimensions:
# x: (B, d, H, W), where B is batch size, d is depth (channels), H is height, and W is width
B, d, H, W = x.shape
x = x.reshape(B, d, H*W).permute(0, 2, 1) # (B, d, H, W) -> (B, d, H*W) -> (B, H*W, d)
return x
else:
raise ValueError(f"Unsupported number of dimensions in input tensor: {len(x.shape)}")

class ScaledDotProductAttention(nn.Module):
"""
Scaled Dot-Product Attention proposed in "Attention Is All You Need"
Compute the dot products of the query with all keys, divide each by sqrt(dim),
and apply a softmax function to obtain the weights on the values
Args: dim, mask
        dim (int): dimension of attention
mask (torch.Tensor): tensor containing indices to be masked
Inputs: query, key, value, mask
- **query** (batch, q_len, d_model): tensor containing projection vector for decoder.
- **key** (batch, k_len, d_model): tensor containing projection vector for encoder.
- **value** (batch, v_len, d_model): tensor containing features of the encoded input sequence.
- **mask** (-): tensor containing indices to be masked
Returns: context, attn
- **context**: tensor containing the context vector from attention mechanism.
- **attn**: tensor containing the attention (alignment) from the encoder outputs.
"""
def __init__(self, dim: int):
super(ScaledDotProductAttention, self).__init__()
self.sqrt_dim = np.sqrt(dim)

def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim

if mask is not None:
score.masked_fill_(mask.view(score.size()), -float('Inf'))

attn = F.softmax(score, -1)
context = torch.bmm(attn, value)
return context, attn

class ViTPooling(nn.Module):
"""
    Implements ViT-style attention pooling: a learnable CLS token attends to the patch tokens through
    Multi-Head Attention as proposed in "Attention Is All You Need", and the attended CLS token is
    returned as the pooled representation.
Instead of performing a single attention function with d_model-dimensional keys, values, and queries,
project the queries, keys and values h times with different, learned linear projections to d_head dimensions.
These are concatenated and once again projected, resulting in the final values.
Multi-head attention allows the model to jointly attend to information from different representation
subspaces at different positions.
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o
where head_i = Attention(Q · W_q, K · W_k, V · W_v)
Args:
        d_model (int): The dimension of keys / values / queries (default: 512)
num_heads (int): The number of attention heads. (default: 8)
Inputs: query, key, value, mask
        - **query** (batch, q_len, d_model): in a transformer, the query comes from one of three places:
            Case 1: the previous decoder layer
            Case 2: the input embedding
            Case 3: the output embedding (masked)
        - **key** (batch, k_len, d_model): in a transformer, the key comes from one of three places:
            Case 1: the output of the encoder
            Case 2: the input embedding
            Case 3: the output embedding (masked)
        - **value** (batch, v_len, d_model): in a transformer, the value comes from one of three places:
            Case 1: the output of the encoder
            Case 2: the input embedding
            Case 3: the output embedding (masked)
- **mask** (-): tensor containing indices to be masked
    Returns: output
        - **output** (batch, d_model): the attended CLS token, used as the pooled representation.
"""
def __init__(self, d_model: int = 512, num_heads: int = 8):
super(ViTPooling, self).__init__()

assert d_model % num_heads == 0, "d_model % num_heads should be zero."

self.d_head = int(d_model / num_heads)
self.num_heads = num_heads
self.scaled_dot_attn = ScaledDotProductAttention(self.d_head)
self.query_proj = nn.Linear(d_model, self.d_head * num_heads)
self.key_proj = nn.Linear(d_model, self.d_head * num_heads)
self.value_proj = nn.Linear(d_model, self.d_head * num_heads)
self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))

    def forward(
            self,
            x: Tensor,
            mask: Optional[Tensor] = None
    ) -> Tensor:

        x = prepare_input(x)
        B = x.size(0)
        cls_token = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_token, x), dim=1)

query = self.query_proj(x).view(B, -1, self.num_heads, self.d_head) # BxQ_LENxNxD
key = self.key_proj(x).view(B, -1, self.num_heads, self.d_head) # BxK_LENxNxD
value = self.value_proj(x).view(B, -1, self.num_heads, self.d_head) # BxV_LENxNxD

query = query.permute(2, 0, 1, 3).contiguous().view(B * self.num_heads, -1, self.d_head) # BNxQ_LENxD
key = key.permute(2, 0, 1, 3).contiguous().view(B * self.num_heads, -1, self.d_head) # BNxK_LENxD
value = value.permute(2, 0, 1, 3).contiguous().view(B * self.num_heads, -1, self.d_head) # BNxV_LENxD

if mask is not None:
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1) # BxNxQ_LENxK_LEN

context, attn = self.scaled_dot_attn(query, key, value, mask)

context = context.view(self.num_heads, B, -1, self.d_head)
context = context.permute(1, 2, 0, 3).contiguous().view(B, -1, self.num_heads * self.d_head) # BxTxND

return context[:, 0]
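A minimal usage sketch of ViTPooling as CLS-token attention pooling (shapes below are assumptions, not part of this diff):

import torch
from timm.layers.vit_pool import ViTPooling  # module added by this PR

pool = ViTPooling(d_model=512, num_heads=8)
feats = torch.randn(2, 512, 7, 7)   # or a (B, N, 512) token tensor
pooled = pool(feats)                # attended CLS token as the pooled output: (2, 512)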