Skip to content

Commit

Permalink
Fixed shape error by unqueezing logits
Browse files Browse the repository at this point in the history
  • Loading branch information
Rabbidon authored Mar 20, 2023
1 parent de69f6d commit 4d8d4d7
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions magma/image_prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def forward(
if self.use_layernorm:
logits = self.ln(logits)

# Added for shape mismatch. No longer needed?
#if logits.ndim == 2:
# logits = logits.unsqueeze(1)
# Added for shape mismatch.
if logits.ndim == 2:
logits = logits.unsqueeze(1)

return logits

0 comments on commit 4d8d4d7

Please sign in to comment.