Skip to content

Commit

Permalink
Fix OOM
Browse files Browse the repository at this point in the history
  • Loading branch information
Quentin-Anthony committed Mar 8, 2023
1 parent b30539d commit de69f6d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
5 changes: 3 additions & 2 deletions magma/image_prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def forward(
if self.use_layernorm:
logits = self.ln(logits)

if logits.ndim == 2:
logits = logits.unsqueeze(1)
# Added for shape mismatch. No longer needed?
#if logits.ndim == 2:
# logits = logits.unsqueeze(1)

return logits
3 changes: 2 additions & 1 deletion magma/magma.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ def __init__(self, config, device=None, init_weights=True):
for param in self.image_prefix.enc.parameters():
param.requires_grad = False

self.lm.to(self.device)
# added for CPU tests. No longer needed and leads to OOM on GPUs.
#self.lm.to(self.device)

def add_adapters(
self,
Expand Down

0 comments on commit de69f6d

Please sign in to comment.