Skip to content

Commit

Permalink
Fix equality checking
Browse files Browse the repository at this point in the history
Signed-off-by: DarkLight1337 <[email protected]>
  • Loading branch information
DarkLight1337 committed Dec 23, 2024
1 parent ab7e84b commit 9f2cdaa
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions vllm/multimodal/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ class PlaceholderRange(TypedDict):
Uses a list instead of a tensor if the dimensions of each element do not match.
"""


def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
"""Equality check between :data:`NestedTensors` objects."""
if isinstance(a, torch.Tensor):
return isinstance(b, torch.Tensor) and bool((a == b).all().item())
if isinstance(b, torch.Tensor):
return isinstance(a, torch.Tensor) and bool((b == a).all().item())

return (len(a) == len(b)
and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)))


BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
Expand All @@ -122,6 +134,14 @@ class MultiModalFieldItem:
modality: str
data: NestedTensors

def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False

a_tupl = (type(self.field), self.modality)
b_tupl = (type(other.field), other.modality)
return a_tupl == b_tupl and nested_tensors_equal(self.data, other.data)


class MultiModalField(ABC):
"""Represents a field in :class:`MultiModalKwargs`."""
Expand Down Expand Up @@ -356,6 +376,16 @@ def as_kwargs(

return cast(BatchedTensorInputs, json_mapped)

def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False
if self._items_by_key != other._items_by_key:
return False

ks = self.keys()
return (ks == other.keys()
and all(nested_tensors_equal(self[k], other[k]) for k in ks))

def get_item(self, key: str, item_index: int) -> MultiModalFieldItem:
return self._items_by_key[key][item_index]

Expand Down

0 comments on commit 9f2cdaa

Please sign in to comment.