From 86ecc61a531b008b81554327b556d057838e53e4 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Fri, 25 Oct 2024 16:42:56 -0600 Subject: [PATCH] [Bugfix] Fix crash with llama 3.2 vision models and guided decoding (#9631) Signed-off-by: Travis Johnson Co-authored-by: pavlo-ruban Co-authored-by: Nick Hill Signed-off-by: Shanshan Wang --- .../guided_decoding/outlines_logits_processors.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index c28bd71c9f682..e1309c31f77e7 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -15,11 +15,11 @@ # limitations under the License. import copy import json -import math from collections import defaultdict from functools import lru_cache from typing import Callable, DefaultDict, Dict, List, Union +import numpy as np import torch from lark import Lark from outlines import grammars @@ -77,9 +77,17 @@ def __call__(self, input_ids: List[int], f"Unsupported instruction type {type(instruction)}") mask = torch.full((scores.shape[-1], ), - -math.inf, + -torch.inf, device=scores.device) - mask[allowed_tokens] = 0 + # The tokenizer may support more token ids than the model can generate, + # eg. Llama 3.2 Vision models have an `<|image|>` token with id 128256 + # but scores.shape == torch.Size([128256]) + # Using NumPy is faster for filtering token ids + allowed_tokens = np.array(allowed_tokens, dtype=np.int64) + allowed_tokens = torch.tensor(allowed_tokens, device=scores.device) + allowed_tokens = allowed_tokens.masked_select( + allowed_tokens < scores.shape[-1]) + mask.index_fill_(0, allowed_tokens, 0) scores.add_(mask) return scores