Skip to content

Commit

Permalink
Fix NoneType Error in flatten Function for Text-Only Tasks in LLAVA Models (#501)
Browse files Browse the repository at this point in the history

* refine the flatten function for text only task

* update text only tinyllava

---------

Co-authored-by: Jinhe Bi <[email protected]>
Co-authored-by: Jinhe Bi <[email protected]>
  • Loading branch information
3 people authored Jan 17, 2025
1 parent 721ee92 commit b1fbf55
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 11 deletions.
7 changes: 5 additions & 2 deletions lmms_eval/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,13 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
return res

def flatten(self, input):
    """Flatten a list of sub-lists into a single flat list.

    Text-only tasks can produce a visual batch that is empty or contains
    ``None`` entries; iterating those would raise a TypeError, so such a
    batch yields ``[]`` instead.

    Args:
        input: List of iterables (each possibly ``None``/empty for
            text-only samples). The name shadows the builtin but is kept
            for caller compatibility.

    Returns:
        A flat list of all items from the non-empty sub-lists, or ``[]``
        when ``input`` is empty or any entry is ``None``.
    """
    # The whole batch is unusable for visual processing if it is empty
    # or any sample contributed no visuals (None) — text-only task.
    if not input or any(i is None for i in input):
        return []
    new_list = []
    for i in input:
        # Skip falsy (empty) sub-lists defensively.
        if i:
            for j in i:
                new_list.append(j)
    return new_list

def generate_until(self, requests: List[Instance]) -> List[str]:
Expand Down
7 changes: 5 additions & 2 deletions lmms_eval/models/llava_onevision.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,13 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
return res

def flatten(self, input):
    """Flatten a list of sub-lists into a single flat list.

    Mirrors the guarded flatten used by the other LLaVA model wrappers:
    a batch that is empty or contains ``None`` entries (text-only task,
    no visuals) returns ``[]`` rather than raising a TypeError.

    Args:
        input: List of iterables, where an element may be ``None`` or
            empty for text-only samples. Parameter name kept as-is for
            caller compatibility despite shadowing the builtin.

    Returns:
        A flat list of items, or ``[]`` for an empty/None-bearing batch.
    """
    # Bail out early: empty batch or any None entry means no visuals.
    if not input or any(i is None for i in input):
        return []
    new_list = []
    for i in input:
        # Only iterate non-empty sub-lists.
        if i:
            for j in i:
                new_list.append(j)
    return new_list

def load_video(self, video_path, max_frames_num):
Expand Down
42 changes: 35 additions & 7 deletions lmms_eval/models/tinyllava.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,13 @@ def tok_decode(self, tokens):
return self.tokenizer.decode([tokens])

def flatten(self, input):
    """Flatten a list of sub-lists into a single flat list.

    Text-only tasks hand this method a visual batch that is empty or
    holds ``None`` placeholders; those cases return ``[]`` instead of
    raising ``TypeError: 'NoneType' object is not iterable``.

    Args:
        input: List of iterables (elements may be ``None``/empty for
            text-only samples). Name retained for caller compatibility
            even though it shadows the builtin.

    Returns:
        A flat list of all items, or ``[]`` when ``input`` is empty or
        contains a ``None`` element.
    """
    # No visuals at all (or a None placeholder) => nothing to flatten.
    if not input or any(i is None for i in input):
        return []
    new_list = []
    for i in input:
        # Guard against empty sub-lists before iterating.
        if i:
            for j in i:
                new_list.append(j)
    return new_list

def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
Expand Down Expand Up @@ -234,17 +237,42 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:

msg = Message()
msg.add_message(prompts_input)

# Process text input and get input_ids
contxt_id = self._text_processor(msg.messages, mode="eval")["input_ids"]
# Add the answer of the second role

# Set the continuation as the second role's response
msg._messages[1]["value"] = continuation
input_ids = self._text_processor(msg.messages, mode="eval")["input_ids"]

# Prepare labels and ensure the correct shape
labels = input_ids.clone()
# Context part no need to calculate for loss
labels[0, : contxt_id.shape[1]] = -100
if labels.dim() == 1:
labels = labels.unsqueeze(0) # Convert to (1, seq_len) if needed

if len(contxt_id.shape) == 1:
contxt_id = contxt_id.unsqueeze(0) # Convert to (1, context_len)

# Mask the context part to ignore it in loss computation
labels[:, : contxt_id.shape[1]] = -100

# Move tensors to the correct device
device = self.device
input_ids = input_ids.to(device)
labels = labels.to(device)

if len(input_ids.shape) == 1:
input_ids = input_ids.unsqueeze(0) # Ensure it is (batch_size, seq_len)

# Handle image input if available
if image is None:
image_sizes = []
with torch.inference_mode():
outputs = self.model(input_ids=input_ids, labels=labels, use_cache=True)
else:
with torch.inference_mode():
outputs = self.model(input_ids=input_ids, labels=labels, images=image, use_cache=True, image_sizes=image_sizes)

with torch.inference_mode():
outputs = self.model(input_ids=input_ids, labels=labels, images=image, use_cache=True, image_sizes=image_sizes)
loss = outputs["loss"]
# loss = torch.exp(loss)
logits = outputs["logits"]
Expand Down

0 comments on commit b1fbf55

Please sign in to comment.