From 4d8d4d765758cf4ba46cdecaf3f94ccf45ee411e Mon Sep 17 00:00:00 2001 From: Edwin Fennell Date: Mon, 20 Mar 2023 13:32:43 +0000 Subject: [PATCH] Fixed shape error by unqueezing logits --- magma/image_prefix.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/magma/image_prefix.py b/magma/image_prefix.py index 60d4847..11e723e 100644 --- a/magma/image_prefix.py +++ b/magma/image_prefix.py @@ -109,8 +109,8 @@ def forward( if self.use_layernorm: logits = self.ln(logits) - # Added for shape mismatch. No longer needed? - #if logits.ndim == 2: - # logits = logits.unsqueeze(1) + # Added for shape mismatch. + if logits.ndim == 2: + logits = logits.unsqueeze(1) return logits