diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py
index 2de73cf5..5e661875 100755
--- a/lmms_eval/models/llava.py
+++ b/lmms_eval/models/llava.py
@@ -278,10 +278,13 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
         return res
 
     def flatten(self, input):
+        if not input or any(i is None for i in input):
+            return []
         new_list = []
         for i in input:
-            for j in i:
-                new_list.append(j)
+            if i:
+                for j in i:
+                    new_list.append(j)
         return new_list
 
     def generate_until(self, requests: List[Instance]) -> List[str]:
diff --git a/lmms_eval/models/llava_onevision.py b/lmms_eval/models/llava_onevision.py
index 685a5c48..e137b5b6 100644
--- a/lmms_eval/models/llava_onevision.py
+++ b/lmms_eval/models/llava_onevision.py
@@ -362,10 +362,13 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
         return res
 
     def flatten(self, input):
+        if not input or any(i is None for i in input):
+            return []
         new_list = []
         for i in input:
-            for j in i:
-                new_list.append(j)
+            if i:
+                for j in i:
+                    new_list.append(j)
         return new_list
 
     def load_video(self, video_path, max_frames_num):
diff --git a/lmms_eval/models/tinyllava.py b/lmms_eval/models/tinyllava.py
index 2aaa4b19..d5dca8e5 100755
--- a/lmms_eval/models/tinyllava.py
+++ b/lmms_eval/models/tinyllava.py
@@ -186,10 +186,13 @@ def tok_decode(self, tokens):
         return self.tokenizer.decode([tokens])
 
     def flatten(self, input):
+        if not input or any(i is None for i in input):
+            return []
         new_list = []
         for i in input:
-            for j in i:
-                new_list.append(j)
+            if i:
+                for j in i:
+                    new_list.append(j)
         return new_list
 
     def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
@@ -234,17 +237,42 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
 
             msg = Message()
             msg.add_message(prompts_input)
+
+            # Process text input and get input_ids
             contxt_id = self._text_processor(msg.messages, mode="eval")["input_ids"]
-            # Add the answer of the second role
+
+            # Set the continuation as the second role's response
             msg._messages[1]["value"] = continuation
             input_ids = self._text_processor(msg.messages, mode="eval")["input_ids"]
 
+            # Prepare labels and ensure the correct shape
             labels = input_ids.clone()
-            # Context part no need to calculate for loss
-            labels[0, : contxt_id.shape[1]] = -100
+            if labels.dim() == 1:
+                labels = labels.unsqueeze(0)  # Convert to (1, seq_len) if needed
+
+            if len(contxt_id.shape) == 1:
+                contxt_id = contxt_id.unsqueeze(0)  # Convert to (1, context_len)
+
+            # Mask the context part to ignore it in loss computation
+            labels[:, : contxt_id.shape[1]] = -100
+
+            # Move tensors to the correct device
+            device = self.device
+            input_ids = input_ids.to(device)
+            labels = labels.to(device)
+
+            if len(input_ids.shape) == 1:
+                input_ids = input_ids.unsqueeze(0)  # Ensure it is (batch_size, seq_len)
+
+            # Handle image input if available
+            if image is None:
+                image_sizes = []
+                with torch.inference_mode():
+                    outputs = self.model(input_ids=input_ids, labels=labels, use_cache=True)
+            else:
+                with torch.inference_mode():
+                    outputs = self.model(input_ids=input_ids, labels=labels, images=image, use_cache=True, image_sizes=image_sizes)
 
-            with torch.inference_mode():
-                outputs = self.model(input_ids=input_ids, labels=labels, images=image, use_cache=True, image_sizes=image_sizes)
             loss = outputs["loss"]
             # loss = torch.exp(loss)
             logits = outputs["logits"]
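
For reference, below is a minimal standalone sketch (not part of the patch) of the two behaviors this diff introduces: the guarded flatten (empty or None-containing input yields [], empty sublists are skipped) and the 1-D to 2-D label masking. The hand-made tensors stand in for the _text_processor output and are an assumption purely for illustration.

import torch

def flatten(input):
    # Guard: empty input, or any sublist that is None, yields an empty list.
    if not input or any(i is None for i in input):
        return []
    new_list = []
    for i in input:
        if i:  # skip empty sublists
            for j in i:
                new_list.append(j)
    return new_list

assert flatten([[1, 2], [3]]) == [1, 2, 3]
assert flatten([[1, 2], None]) == []  # any None sublist empties the result
assert flatten([]) == []
assert flatten([[], [4]]) == [4]

# Label masking: context tokens are set to -100 so the loss ignores them.
contxt_id = torch.tensor([101, 102, 103])           # 1-D, as if from the processor
input_ids = torch.tensor([101, 102, 103, 7, 8, 9])  # context + continuation

labels = input_ids.clone()
if labels.dim() == 1:
    labels = labels.unsqueeze(0)        # (1, seq_len)
if len(contxt_id.shape) == 1:
    contxt_id = contxt_id.unsqueeze(0)  # (1, context_len)
labels[:, : contxt_id.shape[1]] = -100  # mask the context portion

print(labels)  # the first three positions are -100; the continuation is kept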