
Commit

Initial
felix-e-h-p committed Jan 7, 2025
1 parent a3c2251 commit f400805
Showing 5 changed files with 539 additions and 1 deletion.
4 changes: 4 additions & 0 deletions .gitignore
@@ -140,3 +140,7 @@ dmypy.json

# Pyre type checker
.pyre/

# OS
.DS_Store

168 changes: 168 additions & 0 deletions pvnet/models/multimodal/attention_blocks.py
@@ -0,0 +1,168 @@
# attention_blocks.py

"""
Attention blocks for the multimodal dynamic fusion implementation
Foundation module defining the MultiheadAttention, CrossModalAttention, and SelfAttention mechanisms
These attention blocks enable early cross-modal interaction, permitting each modality to learn from the features of the other modalities rather than being processed independently
"""


import torch
import torch.nn.functional as F
from abc import ABC, abstractmethod
from typing import Dict, Optional
from torch import nn


class AbstractAttentionBlock(nn.Module, ABC):
""" Abstract attention base class definition """

# Forward pass
@abstractmethod
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
pass


# Splits the input into multiple heads and scales attention scores for stability
class MultiheadAttention(AbstractAttentionBlock):
""" Multihead attention implementation / definition """

# Initialisation of multihead attention
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.1
):

super().__init__()
if embed_dim % num_heads != 0:
raise ValueError("embed_dim not divisible by num_heads")

self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.scale = self.head_dim ** -0.5

self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)

self.dropout = nn.Dropout(dropout)

# Forward pass - multihead attention
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:

batch_size = query.shape[0]

# Projection and reshape - define attention scores
q = self.q_proj(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))

# Attention weights and subsequent output
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)

return self.out_proj(attn_output)


# Enables a single modality to 'attend' to the others using a dedicated attention block
class CrossModalAttention(AbstractAttentionBlock):
""" CrossModal attention - interaction between multiple modalities """

# Initialisation of CrossModal attention
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.1,
num_modalities: int = 2
):

super().__init__()
self.num_modalities = num_modalities
self.attention_blocks = nn.ModuleList([
MultiheadAttention(embed_dim, num_heads, dropout=dropout)
for _ in range(num_modalities)
])
self.dropout = nn.Dropout(dropout)
self.layer_norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(num_modalities)])

# Forward pass - CrossModal attention
def forward(
self,
modalities: Dict[str, torch.Tensor],
mask: Optional[torch.Tensor] = None
) -> Dict[str, torch.Tensor]:

updated_modalities = {}
modality_keys = list(modalities.keys())

for i, key in enumerate(modality_keys):
query = modalities[key]
# Combine other modalities as key-value pairs
other_modalities = [modalities[k] for k in modality_keys if k != key]
if other_modalities:
key_value = torch.cat(other_modalities, dim=1)

# Apply attention block for this modality
attn_output = self.attention_blocks[i](query, key_value, key_value, mask)
attn_output = self.dropout(attn_output)
updated_modalities[key] = self.layer_norms[i](query + attn_output)
else:
# If no other modalities, pass through
updated_modalities[key] = query

return updated_modalities


# Permits each element in the input sequence to attend to all other elements
class SelfAttention(AbstractAttentionBlock):
""" SelfAttention block for singular modality """

# Initialisation of self attention
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.1
):

super().__init__()
self.attention = MultiheadAttention(embed_dim, num_heads, dropout)
self.layer_norm = nn.LayerNorm(embed_dim)
self.dropout = nn.Dropout(dropout)

# Forward pass - self attention
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:

attn_output = self.attention(x, x, x, mask)
attn_output = self.dropout(attn_output)
return self.layer_norm(x + attn_output)
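
For reference, a minimal usage sketch of the attention blocks added above; the embedding size, batch shape, and modality names ("sat", "nwp") below are illustrative assumptions, not part of this commit:

# Example (illustrative) - early cross-modal interaction between two modalities
import torch
from pvnet.models.multimodal.attention_blocks import CrossModalAttention, SelfAttention

embed_dim, num_heads = 64, 4
cross_attn = CrossModalAttention(embed_dim=embed_dim, num_heads=num_heads, num_modalities=2)
self_attn = SelfAttention(embed_dim=embed_dim, num_heads=num_heads)

# Hypothetical batch of 8 samples with per-modality sequence lengths of 12 and 10
modalities = {
    "sat": torch.randn(8, 12, embed_dim),
    "nwp": torch.randn(8, 10, embed_dim),
}

fused = cross_attn(modalities)      # dict with the same keys, each (batch, seq_len, embed_dim)
refined = self_attn(fused["sat"])   # (8, 12, embed_dim)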

214 changes: 214 additions & 0 deletions pvnet/models/multimodal/fusion_blocks.py
@@ -0,0 +1,214 @@
# fusion_blocks.py

"""
Fusion blocks for the dynamic multimodal fusion implementation
Defines the foundational fusion mechanisms: DynamicFusionModule and ModalityGating
These fusion blocks apply dynamic attention, weighted combinations, and / or gating mechanisms for feature learning
"""


import torch
import torch.nn.functional as F
from torch import nn
from abc import ABC, abstractmethod
from typing import Dict, Optional, Union, List

from pvnet.models.multimodal.attention_blocks import MultiheadAttention


class AbstractFusionBlock(nn.Module, ABC):
""" Abstract fusion base class definition """

# Forward pass
@abstractmethod
def forward(
self,
features: Dict[str, torch.Tensor]
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
pass


class DynamicFusionModule(AbstractFusionBlock):

""" Dynamic fusion implementation / definition """
def __init__(
self,
feature_dims: Dict[str, int],
hidden_dim: int = 256,
num_heads: int = 8,
dropout: float = 0.1,
fusion_method: str = "weighted_sum",
use_residual: bool = True
):

# Initialisation of dynamic fusion module
super().__init__()
self.feature_dims = feature_dims
self.hidden_dim = hidden_dim
self.fusion_method = fusion_method
self.use_residual = use_residual

if fusion_method not in ["weighted_sum", "concat"]:
raise ValueError(f"Invalid fusion method: {fusion_method}")

# Define projections for each modality
# Specified features only considered
self.projections = nn.ModuleDict({
name: nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout)
)
for name, dim in feature_dims.items()
if dim > 0
})

# Cross attention mechanism
self.cross_attention = MultiheadAttention(
embed_dim=hidden_dim,
num_heads=num_heads,
dropout=dropout
)

# Dynamic weighting network
# Weight generation per modality - consistently positive
self.weight_network = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, 1),
nn.Sigmoid()
)

        # Optional output projection for concatenation
        if fusion_method == "concat":
            # Concatenation size is based on the modalities that are actually projected (dim > 0)
            self.output_projection = nn.Sequential(
                nn.Linear(hidden_dim * len(self.projections), hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            )

# Layer normalisation for residual connections
if use_residual:
self.layer_norm = nn.LayerNorm(hidden_dim)

def compute_modality_weights(
self,
attended_features: torch.Tensor,
available_modalities: List[str],
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:

# Computation of dynamic weights for available modalities
batch_size = attended_features.size(0)
num_modalities = len(available_modalities)

# Independent weight generation per modality
weights = self.weight_network(attended_features)

if mask is not None:
weights = weights.masked_fill(~mask, 0.0)

# Normalise weights
weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-9)

return weights

def forward(
self,
features: Dict[str, torch.Tensor],
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:

# Forward pass for dynamic fusion
# Project each modality to common space
projected_features = {
name: self.projections[name](feat)
for name, feat in features.items()
if feat is not None and self.feature_dims[name] > 0
}

if not projected_features:
raise ValueError("Invalid features")

# Stack features for attention and store for residual connection
feature_stack = torch.stack(list(projected_features.values()), dim=1)
input_features = feature_stack

# Cross attention - application
attended_features = self.cross_attention(
feature_stack,
feature_stack,
feature_stack
)

# Apply dynamic weights
weights = self.compute_modality_weights(
attended_features,
list(projected_features.keys()),
mask
)

# Weighted sum or concatenation
if self.fusion_method == "weighted_sum":
weighted_features = attended_features * weights
fused_features = weighted_features.sum(dim=1)
else:
weighted_features = attended_features * weights
fused_features = self.output_projection(
weighted_features.view(weighted_features.size(0), -1)
)

# Apply residual connection
if self.use_residual:
residual = input_features.mean(dim=1)
fused_features = self.layer_norm(fused_features + residual)

return fused_features


class ModalityGating(AbstractFusionBlock):

""" Modality gating mechanism definition """
def __init__(
self,
feature_dims: Dict[str, int],
hidden_dim: int = 256,
dropout: float = 0.1
):
# Initialisation of modality gating module
super().__init__()
self.feature_dims = feature_dims

# Create gate networks for each modality
self.gate_networks = nn.ModuleDict({
name: nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, 1),
nn.Sigmoid()
)
for name, dim in feature_dims.items()
if dim > 0
})

def forward(
self,
features: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:

# Forward pass for modality gating
gated_features = {}

# Gate value and subsequent application
for name, feat in features.items():
if feat is not None and self.feature_dims.get(name, 0) > 0:
gate = self.gate_networks[name](feat)
gated_features[name] = feat * gate

return gated_features
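
Similarly, a minimal sketch of how the fusion blocks above might be composed; the feature dimensions and modality names are illustrative assumptions rather than values from this commit:

# Example (illustrative) - gating followed by dynamic fusion of two modalities
import torch
from pvnet.models.multimodal.fusion_blocks import DynamicFusionModule, ModalityGating

feature_dims = {"sat": 128, "nwp": 64}
gating = ModalityGating(feature_dims, hidden_dim=256)
fusion = DynamicFusionModule(feature_dims, hidden_dim=256, num_heads=8, fusion_method="weighted_sum")

features = {
    "sat": torch.randn(8, 128),
    "nwp": torch.randn(8, 64),
}

gated = gating(features)   # per-modality features scaled by learned gates in (0, 1)
fused = fusion(gated)      # (8, 256) fused representation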