From 9f2cdaad079e21f505664a54275adab832c4eda1 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Mon, 23 Dec 2024 09:07:22 +0000 Subject: [PATCH] Fix equality checking Signed-off-by: DarkLight1337 --- vllm/multimodal/inputs.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 61f70e4b0ecc7..c90e1cd22defc 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -108,6 +108,18 @@ class PlaceholderRange(TypedDict): Uses a list instead of a tensor if the dimensions of each element do not match. """ + +def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool: + """Equality check between :data:`NestedTensors` objects.""" + if isinstance(a, torch.Tensor): + return isinstance(b, torch.Tensor) and bool((a == b).all().item()) + if isinstance(b, torch.Tensor): + return isinstance(a, torch.Tensor) and bool((b == a).all().item()) + + return (len(a) == len(b) + and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))) + + BatchedTensorInputs: TypeAlias = dict[str, NestedTensors] """ A dictionary containing nested tensors which have been batched via @@ -122,6 +134,14 @@ class MultiModalFieldItem: modality: str data: NestedTensors + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return False + + a_tupl = (type(self.field), self.modality) + b_tupl = (type(other.field), other.modality) + return a_tupl == b_tupl and nested_tensors_equal(self.data, other.data) + class MultiModalField(ABC): """Represents a field in :class:`MultiModalKwargs`.""" @@ -356,6 +376,16 @@ def as_kwargs( return cast(BatchedTensorInputs, json_mapped) + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return False + if self._items_by_key != other._items_by_key: + return False + + ks = self.keys() + return (ks == other.keys() + and all(nested_tensors_equal(self[k], other[k]) for k in ks)) + def get_item(self, key: str, item_index: int) -> MultiModalFieldItem: return self._items_by_key[key][item_index]