1 parent a3c2251, commit f400805
Showing 5 changed files with 539 additions and 1 deletion.
@@ -140,3 +140,7 @@ dmypy.json

# Pyre type checker
.pyre/

# OS
.DS_Store
@@ -0,0 +1,168 @@
# attention_blocks.py

"""
Attention blocks for the multimodal dynamic fusion implementation
A foundation script that defines the MultiheadAttention, CrossModalAttention, and SelfAttention mechanisms
The aforementioned attention blocks enable early cross-modal interaction, which permits each modality to learn from the features of other modalities as opposed to processing each independently
"""


import torch
import torch.nn.functional as F
from abc import ABC, abstractmethod
from typing import Dict, Optional
from torch import nn


class AbstractAttentionBlock(nn.Module, ABC):
    """ Abstract attention base class definition """

    # Forward pass
    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        pass


# Splits input into multiple heads - scales attention scores for stability
class MultiheadAttention(AbstractAttentionBlock):
    """ Multihead attention implementation / definition """

    # Initialisation of multihead attention
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.1
    ):

        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim not divisible by num_heads")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)

    # Forward pass - multihead attention
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:

        batch_size = query.shape[0]

        # Projection and reshape - define attention scores
        q = self.q_proj(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Attention weights and subsequent output
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)

        return self.out_proj(attn_output)


# Enables a single modality to 'attend' to the others using a modality-specific attention block
class CrossModalAttention(AbstractAttentionBlock):
    """ CrossModal attention - interaction between multiple modalities """

    # Initialisation of CrossModal attention
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.1,
        num_modalities: int = 2
    ):

        super().__init__()
        self.num_modalities = num_modalities
        self.attention_blocks = nn.ModuleList([
            MultiheadAttention(embed_dim, num_heads, dropout=dropout)
            for _ in range(num_modalities)
        ])
        self.dropout = nn.Dropout(dropout)
        self.layer_norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(num_modalities)])

    # Forward pass - CrossModal attention
    def forward(
        self,
        modalities: Dict[str, torch.Tensor],
        mask: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:

        updated_modalities = {}
        modality_keys = list(modalities.keys())

        for i, key in enumerate(modality_keys):
            query = modalities[key]
            # Combine other modalities as key-value pairs
            other_modalities = [modalities[k] for k in modality_keys if k != key]
            if other_modalities:
                key_value = torch.cat(other_modalities, dim=1)

                # Apply attention block for this modality
                attn_output = self.attention_blocks[i](query, key_value, key_value, mask)
                attn_output = self.dropout(attn_output)
                updated_modalities[key] = self.layer_norms[i](query + attn_output)
            else:
                # If no other modalities, pass through
                updated_modalities[key] = query

        return updated_modalities


# Permits each element in the input sequence to attend to all other elements
class SelfAttention(AbstractAttentionBlock):
    """ SelfAttention block for a single modality """

    # Initialisation of self attention
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.1
    ):

        super().__init__()
        self.attention = MultiheadAttention(embed_dim, num_heads, dropout)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    # Forward pass - self attention
    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:

        attn_output = self.attention(x, x, x, mask)
        attn_output = self.dropout(attn_output)
        return self.layer_norm(x + attn_output)
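A minimal usage sketch of the attention blocks above (not part of the commit). The batch size, sequence lengths, and the "sat" / "nwp" modality names are illustrative assumptions, chosen only to show the expected tensor shapes.

import torch

from pvnet.models.multimodal.attention_blocks import CrossModalAttention, SelfAttention

embed_dim, num_heads = 256, 8
sat = torch.randn(4, 12, embed_dim)   # assumed satellite tokens: (batch, seq_len, embed_dim)
nwp = torch.randn(4, 6, embed_dim)    # assumed NWP tokens

# Self-attention within a single modality; output keeps the input shape
self_attn = SelfAttention(embed_dim, num_heads)
print(self_attn(sat).shape)           # torch.Size([4, 12, 256])

# Cross-modal attention: each modality attends to the concatenation of the others
cross_attn = CrossModalAttention(embed_dim, num_heads, num_modalities=2)
updated = cross_attn({"sat": sat, "nwp": nwp})
print({name: feats.shape for name, feats in updated.items()})
# {'sat': torch.Size([4, 12, 256]), 'nwp': torch.Size([4, 6, 256])}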
@@ -0,0 +1,214 @@
# fusion_blocks.py

"""
Fusion blocks for the dynamic multimodal fusion implementation
Definition of the foundational fusion mechanisms: DynamicFusionModule and ModalityGating
The aforementioned fusion blocks apply dynamic attention, weighted combinations and / or gating mechanisms for feature learning
"""


import torch
import torch.nn.functional as F
from torch import nn
from abc import ABC, abstractmethod
from typing import Dict, Optional, Union, List

from pvnet.models.multimodal.attention_blocks import MultiheadAttention


class AbstractFusionBlock(nn.Module, ABC):
    """ Abstract fusion base class definition """

    # Forward pass
    @abstractmethod
    def forward(
        self,
        features: Dict[str, torch.Tensor]
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        pass


class DynamicFusionModule(AbstractFusionBlock):
    """ Dynamic fusion implementation / definition """

    def __init__(
        self,
        feature_dims: Dict[str, int],
        hidden_dim: int = 256,
        num_heads: int = 8,
        dropout: float = 0.1,
        fusion_method: str = "weighted_sum",
        use_residual: bool = True
    ):

        # Initialisation of dynamic fusion module
        super().__init__()
        self.feature_dims = feature_dims
        self.hidden_dim = hidden_dim
        self.fusion_method = fusion_method
        self.use_residual = use_residual

        if fusion_method not in ["weighted_sum", "concat"]:
            raise ValueError(f"Invalid fusion method: {fusion_method}")

        # Define projections for each modality
        # Specified features only considered
        self.projections = nn.ModuleDict({
            name: nn.Sequential(
                nn.Linear(dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            )
            for name, dim in feature_dims.items()
            if dim > 0
        })

        # Cross attention mechanism
        self.cross_attention = MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout
        )

        # Dynamic weighting network
        # Weight generation per modality - consistently positive
        self.weight_network = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

        # Optional output projection for concatenation
        if fusion_method == "concat":
            self.output_projection = nn.Sequential(
                nn.Linear(hidden_dim * len(feature_dims), hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            )

        # Layer normalisation for residual connections
        if use_residual:
            self.layer_norm = nn.LayerNorm(hidden_dim)

    def compute_modality_weights(
        self,
        attended_features: torch.Tensor,
        available_modalities: List[str],
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:

        # Computation of dynamic weights for available modalities
        batch_size = attended_features.size(0)
        num_modalities = len(available_modalities)

        # Independent weight generation per modality
        weights = self.weight_network(attended_features)

        if mask is not None:
            weights = weights.masked_fill(~mask, 0.0)

        # Normalise weights
        weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-9)

        return weights

    def forward(
        self,
        features: Dict[str, torch.Tensor],
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:

        # Forward pass for dynamic fusion
        # Project each modality to common space
        projected_features = {
            name: self.projections[name](feat)
            for name, feat in features.items()
            if feat is not None and self.feature_dims[name] > 0
        }

        if not projected_features:
            raise ValueError("Invalid features")

        # Stack features for attention and store for residual connection
        feature_stack = torch.stack(list(projected_features.values()), dim=1)
        input_features = feature_stack

        # Cross attention - application
        attended_features = self.cross_attention(
            feature_stack,
            feature_stack,
            feature_stack
        )

        # Apply dynamic weights
        weights = self.compute_modality_weights(
            attended_features,
            list(projected_features.keys()),
            mask
        )

        # Weighted sum or concatenation
        if self.fusion_method == "weighted_sum":
            weighted_features = attended_features * weights
            fused_features = weighted_features.sum(dim=1)
        else:
            weighted_features = attended_features * weights
            fused_features = self.output_projection(
                weighted_features.view(weighted_features.size(0), -1)
            )

        # Apply residual connection
        if self.use_residual:
            residual = input_features.mean(dim=1)
            fused_features = self.layer_norm(fused_features + residual)

        return fused_features


class ModalityGating(AbstractFusionBlock):
    """ Modality gating mechanism definition """

    def __init__(
        self,
        feature_dims: Dict[str, int],
        hidden_dim: int = 256,
        dropout: float = 0.1
    ):
        # Initialisation of modality gating module
        super().__init__()
        self.feature_dims = feature_dims

        # Create gate networks for each modality
        self.gate_networks = nn.ModuleDict({
            name: nn.Sequential(
                nn.Linear(dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, 1),
                nn.Sigmoid()
            )
            for name, dim in feature_dims.items()
            if dim > 0
        })

    def forward(
        self,
        features: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:

        # Forward pass for modality gating
        gated_features = {}

        # Gate value and subsequent application
        for name, feat in features.items():
            if feat is not None and self.feature_dims.get(name, 0) > 0:
                gate = self.gate_networks[name](feat)
                gated_features[name] = feat * gate

        return gated_features
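A minimal usage sketch of the fusion blocks above (not part of the commit). The modality names and feature dimensions are illustrative assumptions; they only need to match the keys passed at call time.

import torch

from pvnet.models.multimodal.fusion_blocks import DynamicFusionModule, ModalityGating

feature_dims = {"sat": 128, "nwp": 64}                          # assumed per-modality dims
features = {name: torch.randn(4, dim) for name, dim in feature_dims.items()}

# Gate each modality's features by a learned per-sample scalar in (0, 1)
gating = ModalityGating(feature_dims, hidden_dim=256)
gated = gating(features)                                        # same shapes as the inputs

# Project to a common space, cross-attend, then take a dynamically weighted sum
fusion = DynamicFusionModule(feature_dims, hidden_dim=256, num_heads=8, fusion_method="weighted_sum")
fused = fusion(gated)
print(fused.shape)                                              # torch.Size([4, 256])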