Skip to content

Commit

Permalink
Merge pull request #2 from Quentin-Anthony/shape-fix
Browse files Browse the repository at this point in the history
Fixed shape error by unqueezing logits
  • Loading branch information
Quentin-Anthony authored Mar 23, 2023
2 parents 27f81d6 + 4d8d4d7 commit d793893
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions magma/image_prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def forward(
if self.use_layernorm:
logits = self.ln(logits)

# Added for shape mismatch. No longer needed?
#if logits.ndim == 2:
# logits = logits.unsqueeze(1)
# Added for shape mismatch.
if logits.ndim == 2:
logits = logits.unsqueeze(1)

return logits

0 comments on commit d793893

Please sign in to comment.