Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
felix-e-h-p committed Jan 7, 2025
1 parent 36005b2 commit 8c1b4e1
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion pvnet/models/multimodal/fusion_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(
if use_residual:
self.layer_norm = nn.LayerNorm(hidden_dim)


def compute_modality_weights(
self,
attended_features: torch.Tensor,
Expand All @@ -111,13 +112,16 @@ def compute_modality_weights(
weights = self.weight_network(attended_features)

if mask is not None:
# Reshape mask to match weights dimension
mask = mask.unsqueeze(-1)
weights = weights.masked_fill(~mask, 0.0)

# Normalise weights
weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-9)

return weights



def forward(
self,
features: Dict[str, torch.Tensor],
Expand Down

0 comments on commit 8c1b4e1

Please sign in to comment.