diff --git a/magma/image_prefix.py b/magma/image_prefix.py index 6c5ebff..60d4847 100644 --- a/magma/image_prefix.py +++ b/magma/image_prefix.py @@ -109,7 +109,8 @@ def forward( if self.use_layernorm: logits = self.ln(logits) - if logits.ndim == 2: - logits = logits.unsqueeze(1) + # Added to work around a shape mismatch. No longer needed? + #if logits.ndim == 2: + # logits = logits.unsqueeze(1) return logits diff --git a/magma/magma.py b/magma/magma.py index 7c7b826..a2df2a2 100644 --- a/magma/magma.py +++ b/magma/magma.py @@ -106,7 +106,8 @@ def __init__(self, config, device=None, init_weights=True): for param in self.image_prefix.enc.parameters(): param.requires_grad = False - self.lm.to(self.device) + # Added for CPU tests. No longer needed, and it causes out-of-memory (OOM) errors on GPUs. + #self.lm.to(self.device) def add_adapters( self,