diff --git a/pvnet/models/multimodal/fusion_blocks.py b/pvnet/models/multimodal/fusion_blocks.py index 3bc1573b..e7b4f844 100644 --- a/pvnet/models/multimodal/fusion_blocks.py +++ b/pvnet/models/multimodal/fusion_blocks.py @@ -96,6 +96,7 @@ def __init__( if use_residual: self.layer_norm = nn.LayerNorm(hidden_dim) + def compute_modality_weights( self, attended_features: torch.Tensor, @@ -111,13 +112,16 @@ def compute_modality_weights( weights = self.weight_network(attended_features) if mask is not None: + # Reshape mask to match weights dimension + mask = mask.unsqueeze(-1) weights = weights.masked_fill(~mask, 0.0) # Normalise weights weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-9) return weights - + + def forward( self, features: Dict[str, torch.Tensor],